1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI HTTP API.
///
/// Also used as the keychain "url" under which the user's API key is stored, read and deleted.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions exposed under the `assistant` namespace. Handlers are registered in `init`;
// several of them also back the header buttons rendered by `AssistantPanel` (their tooltips
// reference these actions so the bound keystroke is shown).
actions!(
    assistant,
    [
        // Opens a fresh conversation in the assistant panel.
        NewConversation,
        // Sends the selected message(s) to the model.
        Assist,
        // Splits the current message (handled by `ConversationEditor::split`).
        Split,
        // Cycles the role of the message under the cursor.
        CycleMessageRole,
        // Quotes the active editor's selection into the conversation.
        QuoteSelection,
        // Toggles keyboard focus to/from the assistant panel.
        ToggleFocus,
        // Forgets the stored OpenAI API key and shows the key prompt again.
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`] and interpreted by the workspace dock through the
/// `Panel` trait implementation.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is zoomed.
    ZoomOut,
    /// Treated as a focus event by `Panel::is_focus_event`.
    Focus,
    /// Treated as a close request by `Panel::should_close_on_event`.
    Close,
    /// Emitted when a settings change moves the configured dock position.
    DockPositionChanged,
}
102
/// Dockable panel hosting assistant conversations, the list of saved conversations, and the
/// OpenAI API-key prompt.
pub struct AssistantPanel {
    // Owning workspace; used by the quote button to dispatch `quote_selection`.
    workspace: WeakViewHandle<Workspace>,
    // User-resized width / height; `None` falls back to the settings defaults in `size`.
    width: Option<f32>,
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the saved list.
    active_editor_index: Option<usize>,
    // Index that was active before the last change; restored by the hamburger button.
    prev_active_editor_index: Option<usize>,
    // Open conversation editors, in the order they were opened.
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Metadata for saved conversations, refreshed by `_watch_saved_conversations`.
    saved_conversations: Vec<SavedConversationMetadata>,
    // Scroll state of the saved-conversations `UniformList`.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    // Updated in `focus_in` / `focus_out`; header tools only render while focused.
    has_focus: bool,
    // Toolbar hosting the `BufferSearchBar` for the active conversation editor.
    toolbar: ViewHandle<Toolbar>,
    // API key shared with every conversation; `None` until loaded or entered.
    api_key: Rc<RefCell<Option<String>>>,
    // `Some` while the panel is prompting for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Ensures the environment / keychain lookup in `set_active` happens only once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Settings observer plus per-conversation subscriptions.
    subscriptions: Vec<Subscription>,
    // Background task that re-lists saved conversations when the directory changes.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 cx.notify();
196 })];
197
198 this
199 })
200 })
201 })
202 }
203
204 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
205 let editor = cx.add_view(|cx| {
206 ConversationEditor::new(
207 self.api_key.clone(),
208 self.languages.clone(),
209 self.fs.clone(),
210 cx,
211 )
212 });
213 self.add_conversation(editor.clone(), cx);
214 editor
215 }
216
217 fn add_conversation(
218 &mut self,
219 editor: ViewHandle<ConversationEditor>,
220 cx: &mut ViewContext<Self>,
221 ) {
222 self.subscriptions
223 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
224
225 let conversation = editor.read(cx).conversation.clone();
226 self.subscriptions
227 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
228
229 let index = self.editors.len();
230 self.editors.push(editor);
231 self.set_active_editor_index(Some(index), cx);
232 }
233
234 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
235 self.prev_active_editor_index = self.active_editor_index;
236 self.active_editor_index = index;
237 if let Some(editor) = self.active_editor() {
238 let editor = editor.read(cx).editor.clone();
239 self.toolbar.update(cx, |toolbar, cx| {
240 toolbar.set_active_item(Some(&editor), cx);
241 });
242 if self.has_focus(cx) {
243 cx.focus(&editor);
244 }
245 } else {
246 self.toolbar.update(cx, |toolbar, cx| {
247 toolbar.set_active_item(None, cx);
248 });
249 }
250
251 cx.notify();
252 }
253
254 fn handle_conversation_editor_event(
255 &mut self,
256 _: ViewHandle<ConversationEditor>,
257 event: &ConversationEditorEvent,
258 cx: &mut ViewContext<Self>,
259 ) {
260 match event {
261 ConversationEditorEvent::TabContentChanged => cx.notify(),
262 }
263 }
264
265 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
266 if let Some(api_key) = self
267 .api_key_editor
268 .as_ref()
269 .map(|editor| editor.read(cx).text(cx))
270 {
271 if !api_key.is_empty() {
272 cx.platform()
273 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
274 .log_err();
275 *self.api_key.borrow_mut() = Some(api_key);
276 self.api_key_editor.take();
277 cx.focus_self();
278 cx.notify();
279 }
280 } else {
281 cx.propagate_action();
282 }
283 }
284
285 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
286 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
287 self.api_key.take();
288 self.api_key_editor = Some(build_api_key_editor(cx));
289 cx.focus_self();
290 cx.notify();
291 }
292
293 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
294 if self.zoomed {
295 cx.emit(AssistantPanelEvent::ZoomOut)
296 } else {
297 cx.emit(AssistantPanelEvent::ZoomIn)
298 }
299 }
300
301 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
302 let mut propagate_action = true;
303 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
304 search_bar.update(cx, |search_bar, cx| {
305 if search_bar.show(cx) {
306 search_bar.search_suggested(cx);
307 if action.focus {
308 search_bar.select_query(cx);
309 cx.focus_self();
310 }
311 propagate_action = false
312 }
313 });
314 }
315 if propagate_action {
316 cx.propagate_action();
317 }
318 }
319
320 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
321 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
322 if !search_bar.read(cx).is_dismissed() {
323 search_bar.update(cx, |search_bar, cx| {
324 search_bar.dismiss(&Default::default(), cx)
325 });
326 return;
327 }
328 }
329 cx.propagate_action();
330 }
331
332 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
335 }
336 }
337
338 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
339 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
340 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
341 }
342 }
343
344 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
345 self.editors.get(self.active_editor_index?)
346 }
347
348 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
349 enum History {}
350 let theme = theme::current(cx);
351 let tooltip_style = theme::current(cx).tooltip.clone();
352 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
353 let style = theme.assistant.hamburger_button.style_for(state);
354 Svg::for_style(style.icon.clone())
355 .contained()
356 .with_style(style.container)
357 })
358 .with_cursor_style(CursorStyle::PointingHand)
359 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
360 if this.active_editor().is_some() {
361 this.set_active_editor_index(None, cx);
362 } else {
363 this.set_active_editor_index(this.prev_active_editor_index, cx);
364 }
365 })
366 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
367 }
368
369 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
370 if self.active_editor().is_some() {
371 vec![
372 Self::render_split_button(cx).into_any(),
373 Self::render_quote_button(cx).into_any(),
374 Self::render_assist_button(cx).into_any(),
375 ]
376 } else {
377 Default::default()
378 }
379 }
380
381 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
382 let theme = theme::current(cx);
383 let tooltip_style = theme::current(cx).tooltip.clone();
384 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
385 let style = theme.assistant.split_button.style_for(state);
386 Svg::for_style(style.icon.clone())
387 .contained()
388 .with_style(style.container)
389 })
390 .with_cursor_style(CursorStyle::PointingHand)
391 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
392 if let Some(active_editor) = this.active_editor() {
393 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
394 }
395 })
396 .with_tooltip::<Split>(
397 1,
398 "Split Message".into(),
399 Some(Box::new(Split)),
400 tooltip_style,
401 cx,
402 )
403 }
404
405 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
406 let theme = theme::current(cx);
407 let tooltip_style = theme::current(cx).tooltip.clone();
408 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
409 let style = theme.assistant.assist_button.style_for(state);
410 Svg::for_style(style.icon.clone())
411 .contained()
412 .with_style(style.container)
413 })
414 .with_cursor_style(CursorStyle::PointingHand)
415 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
416 if let Some(active_editor) = this.active_editor() {
417 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
418 }
419 })
420 .with_tooltip::<Assist>(
421 1,
422 "Assist".into(),
423 Some(Box::new(Assist)),
424 tooltip_style,
425 cx,
426 )
427 }
428
429 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
430 let theme = theme::current(cx);
431 let tooltip_style = theme::current(cx).tooltip.clone();
432 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
433 let style = theme.assistant.quote_button.style_for(state);
434 Svg::for_style(style.icon.clone())
435 .contained()
436 .with_style(style.container)
437 })
438 .with_cursor_style(CursorStyle::PointingHand)
439 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
440 if let Some(workspace) = this.workspace.upgrade(cx) {
441 cx.window_context().defer(move |cx| {
442 workspace.update(cx, |workspace, cx| {
443 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
444 });
445 });
446 }
447 })
448 .with_tooltip::<QuoteSelection>(
449 1,
450 "Quote Selection".into(),
451 Some(Box::new(QuoteSelection)),
452 tooltip_style,
453 cx,
454 )
455 }
456
457 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
458 let theme = theme::current(cx);
459 let tooltip_style = theme::current(cx).tooltip.clone();
460 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
461 let style = theme.assistant.plus_button.style_for(state);
462 Svg::for_style(style.icon.clone())
463 .contained()
464 .with_style(style.container)
465 })
466 .with_cursor_style(CursorStyle::PointingHand)
467 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
468 this.new_conversation(cx);
469 })
470 .with_tooltip::<NewConversation>(
471 1,
472 "New Conversation".into(),
473 Some(Box::new(NewConversation)),
474 tooltip_style,
475 cx,
476 )
477 }
478
479 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
480 enum ToggleZoomButton {}
481
482 let theme = theme::current(cx);
483 let tooltip_style = theme::current(cx).tooltip.clone();
484 let style = if self.zoomed {
485 &theme.assistant.zoom_out_button
486 } else {
487 &theme.assistant.zoom_in_button
488 };
489
490 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
491 let style = style.style_for(state);
492 Svg::for_style(style.icon.clone())
493 .contained()
494 .with_style(style.container)
495 })
496 .with_cursor_style(CursorStyle::PointingHand)
497 .on_click(MouseButton::Left, |_, this, cx| {
498 this.toggle_zoom(&ToggleZoom, cx);
499 })
500 .with_tooltip::<ToggleZoom>(
501 0,
502 if self.zoomed {
503 "Zoom Out".into()
504 } else {
505 "Zoom In".into()
506 },
507 Some(Box::new(ToggleZoom)),
508 tooltip_style,
509 cx,
510 )
511 }
512
513 fn render_saved_conversation(
514 &mut self,
515 index: usize,
516 cx: &mut ViewContext<Self>,
517 ) -> impl Element<Self> {
518 let conversation = &self.saved_conversations[index];
519 let path = conversation.path.clone();
520 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
521 let style = &theme::current(cx).assistant.saved_conversation;
522 Flex::row()
523 .with_child(
524 Label::new(
525 conversation.mtime.format("%F %I:%M%p").to_string(),
526 style.saved_at.text.clone(),
527 )
528 .aligned()
529 .contained()
530 .with_style(style.saved_at.container),
531 )
532 .with_child(
533 Label::new(conversation.title.clone(), style.title.text.clone())
534 .aligned()
535 .contained()
536 .with_style(style.title.container),
537 )
538 .contained()
539 .with_style(*style.container.style_for(state))
540 })
541 .with_cursor_style(CursorStyle::PointingHand)
542 .on_click(MouseButton::Left, move |_, this, cx| {
543 this.open_conversation(path.clone(), cx)
544 .detach_and_log_err(cx)
545 })
546 }
547
548 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
549 if let Some(ix) = self.editor_index_for_path(&path, cx) {
550 self.set_active_editor_index(Some(ix), cx);
551 return Task::ready(Ok(()));
552 }
553
554 let fs = self.fs.clone();
555 let api_key = self.api_key.clone();
556 let languages = self.languages.clone();
557 cx.spawn(|this, mut cx| async move {
558 let saved_conversation = fs.load(&path).await?;
559 let saved_conversation = serde_json::from_str(&saved_conversation)?;
560 let conversation = cx.add_model(|cx| {
561 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
562 });
563 this.update(&mut cx, |this, cx| {
564 // If, by the time we've loaded the conversation, the user has already opened
565 // the same conversation, we don't want to open it again.
566 if let Some(ix) = this.editor_index_for_path(&path, cx) {
567 this.set_active_editor_index(Some(ix), cx);
568 } else {
569 let editor = cx
570 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
571 this.add_conversation(editor, cx);
572 }
573 })?;
574 Ok(())
575 })
576 }
577
578 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
579 self.editors
580 .iter()
581 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
582 }
583}
584
585fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
586 cx.add_view(|cx| {
587 let mut editor = Editor::single_line(
588 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
589 cx,
590 );
591 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
592 editor
593 })
594}
595
// The panel emits `AssistantPanelEvent`s, which the dock consumes via the `Panel` impl.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
599
impl View for AssistantPanel {
    fn ui_name() -> &'static str {
        "AssistantPanel"
    }

    /// Renders either the API-key prompt (while `api_key_editor` is set) or the main layout.
    ///
    /// The main layout is a header tab bar, the toolbar (unless hidden), and then either the
    /// active conversation editor or the list of saved conversations.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        let theme = &theme::current(cx);
        let style = &theme.assistant;
        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
            Flex::column()
                .with_child(
                    Text::new(
                        "Paste your OpenAI API key and press Enter to use the assistant",
                        style.api_key_prompt.text.clone(),
                    )
                    .aligned(),
                )
                .with_child(
                    ChildView::new(api_key_editor, cx)
                        .contained()
                        .with_style(style.api_key_editor.container)
                        .aligned(),
                )
                .contained()
                .with_style(style.api_key_prompt.container)
                .aligned()
                .into_any()
        } else {
            // Header: history toggle and the active conversation's title, then (only while the
            // panel is focused) the editor tools, the "+" button and the zoom toggle.
            let title = self.active_editor().map(|editor| {
                Label::new(editor.read(cx).title(cx), style.title.text.clone())
                    .contained()
                    .with_style(style.title.container)
                    .aligned()
                    .left()
                    .flex(1., false)
            });
            let mut header = Flex::row()
                .with_child(Self::render_hamburger_button(cx).aligned())
                .with_children(title);
            if self.has_focus {
                header.add_children(
                    self.render_editor_tools(cx)
                        .into_iter()
                        .map(|tool| tool.aligned().flex_float()),
                );
                header.add_child(Self::render_plus_button(cx).aligned().flex_float());
                header.add_child(self.render_zoom_button(cx).aligned());
            }

            Flex::column()
                .with_child(
                    header
                        .contained()
                        .with_style(theme.workspace.tab_bar.container)
                        .expanded()
                        .constrained()
                        .with_height(theme.workspace.tab_bar.height),
                )
                .with_children(if self.toolbar.read(cx).hidden() {
                    None
                } else {
                    Some(ChildView::new(&self.toolbar, cx).expanded())
                })
                // Body: the active conversation, or the saved-conversation list when none is active.
                .with_child(if let Some(editor) = self.active_editor() {
                    ChildView::new(editor, cx).flex(1., true).into_any()
                } else {
                    UniformList::new(
                        self.saved_conversations_list_state.clone(),
                        self.saved_conversations.len(),
                        cx,
                        |this, range, items, cx| {
                            for ix in range {
                                items.push(this.render_saved_conversation(ix, cx).into_any());
                            }
                        },
                    )
                    .flex(1., true)
                    .into_any()
                })
                .into_any()
        }
    }

    /// Records focus, informs the toolbar, and — when the panel itself received focus —
    /// forwards it to the active conversation editor or, failing that, the API-key prompt.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = true;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
        cx.notify();
        if cx.is_self_focused() {
            if let Some(editor) = self.active_editor() {
                cx.focus(editor);
            } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
                cx.focus(api_key_editor);
            }
        }
    }

    /// Records loss of focus and informs the toolbar.
    fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = false;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
        cx.notify();
    }
}
704
705impl Panel for AssistantPanel {
706 fn position(&self, cx: &WindowContext) -> DockPosition {
707 match settings::get::<AssistantSettings>(cx).dock {
708 AssistantDockPosition::Left => DockPosition::Left,
709 AssistantDockPosition::Bottom => DockPosition::Bottom,
710 AssistantDockPosition::Right => DockPosition::Right,
711 }
712 }
713
714 fn position_is_valid(&self, _: DockPosition) -> bool {
715 true
716 }
717
718 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
719 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
720 let dock = match position {
721 DockPosition::Left => AssistantDockPosition::Left,
722 DockPosition::Bottom => AssistantDockPosition::Bottom,
723 DockPosition::Right => AssistantDockPosition::Right,
724 };
725 settings.dock = Some(dock);
726 });
727 }
728
729 fn size(&self, cx: &WindowContext) -> f32 {
730 let settings = settings::get::<AssistantSettings>(cx);
731 match self.position(cx) {
732 DockPosition::Left | DockPosition::Right => {
733 self.width.unwrap_or_else(|| settings.default_width)
734 }
735 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
736 }
737 }
738
739 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
740 match self.position(cx) {
741 DockPosition::Left | DockPosition::Right => self.width = Some(size),
742 DockPosition::Bottom => self.height = Some(size),
743 }
744 cx.notify();
745 }
746
747 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
748 matches!(event, AssistantPanelEvent::ZoomIn)
749 }
750
751 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
752 matches!(event, AssistantPanelEvent::ZoomOut)
753 }
754
755 fn is_zoomed(&self, _: &WindowContext) -> bool {
756 self.zoomed
757 }
758
759 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
760 self.zoomed = zoomed;
761 cx.notify();
762 }
763
764 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
765 if active {
766 if self.api_key.borrow().is_none() && !self.has_read_credentials {
767 self.has_read_credentials = true;
768 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
769 Some(api_key)
770 } else if let Some((_, api_key)) = cx
771 .platform()
772 .read_credentials(OPENAI_API_URL)
773 .log_err()
774 .flatten()
775 {
776 String::from_utf8(api_key).log_err()
777 } else {
778 None
779 };
780 if let Some(api_key) = api_key {
781 *self.api_key.borrow_mut() = Some(api_key);
782 } else if self.api_key_editor.is_none() {
783 self.api_key_editor = Some(build_api_key_editor(cx));
784 cx.notify();
785 }
786 }
787
788 if self.editors.is_empty() {
789 self.new_conversation(cx);
790 }
791 }
792 }
793
794 fn icon_path(&self, cx: &WindowContext) -> Option<&'static str> {
795 settings::get::<AssistantSettings>(cx)
796 .button
797 .then(|| "icons/robot_14.svg")
798 }
799
800 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
801 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
802 }
803
804 fn should_change_position_on_event(event: &Self::Event) -> bool {
805 matches!(event, AssistantPanelEvent::DockPositionChanged)
806 }
807
808 fn should_activate_on_event(_: &Self::Event) -> bool {
809 false
810 }
811
812 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
813 matches!(event, AssistantPanelEvent::Close)
814 }
815
816 fn has_focus(&self, _: &WindowContext) -> bool {
817 self.has_focus
818 }
819
820 fn is_focus_event(event: &Self::Event) -> bool {
821 matches!(event, AssistantPanelEvent::Focus)
822 }
823}
824
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Emitted whenever the conversation's buffer is edited.
    MessagesEdited,
    // NOTE(review): emitted outside the visible code; presumably when `summary` changes.
    SummaryChanged,
    // NOTE(review): emitted outside the visible code; presumably as completion text streams in.
    StreamedCompletion,
}
830
/// Short summary text of a conversation, persisted as `SavedConversation::summary`.
#[derive(Default)]
struct Summary {
    text: String,
    // Summaries loaded from disk are marked `done`; presumably `false` while one is being
    // generated — confirm against the summarization code.
    done: bool,
}
836
/// A chat conversation: one Markdown buffer split into messages by anchors, plus per-message
/// metadata, token accounting, and persistence state.
struct Conversation {
    // Buffer holding the text of every message back to back.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message, in buffer order.
    message_anchors: Vec<MessageAnchor>,
    // Role, send time and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Next id to hand out (allocated with `post_inc`).
    next_message_id: MessageId,
    summary: Option<Summary>,
    // NOTE(review): presumably the in-flight summarization task.
    pending_summary: Task<Option<()>>,
    // NOTE(review): not used in the visible code; presumably counts issued completions.
    completion_count: usize,
    // NOTE(review): presumably the completion requests still streaming.
    pending_completions: Vec<PendingCompletion>,
    // OpenAI model name used for requests and token counting.
    model: String,
    // Last computed token count; `None` until the first count completes.
    token_count: Option<usize>,
    // Context-window size of `model`, per `tiktoken_rs`.
    max_token_count: usize,
    // Debounced token-count task; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    // API key shared with the assistant panel.
    api_key: Rc<RefCell<Option<String>>>,
    // NOTE(review): presumably the in-flight save to disk.
    pending_save: Task<Result<()>>,
    // File backing this conversation; `None` for a new, unsaved conversation.
    path: Option<PathBuf>,
    // Subscription to `buffer` events.
    _subscriptions: Vec<Subscription>,
}
855
// Conversations emit `ConversationEvent`s to observers such as the conversation editor.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
859
860impl Conversation {
861 fn new(
862 api_key: Rc<RefCell<Option<String>>>,
863 language_registry: Arc<LanguageRegistry>,
864 cx: &mut ModelContext<Self>,
865 ) -> Self {
866 let model = "gpt-3.5-turbo-0613";
867 let markdown = language_registry.language_for_name("Markdown");
868 let buffer = cx.add_model(|cx| {
869 let mut buffer = Buffer::new(0, "", cx);
870 buffer.set_language_registry(language_registry);
871 cx.spawn_weak(|buffer, mut cx| async move {
872 let markdown = markdown.await?;
873 let buffer = buffer
874 .upgrade(&cx)
875 .ok_or_else(|| anyhow!("buffer was dropped"))?;
876 buffer.update(&mut cx, |buffer, cx| {
877 buffer.set_language(Some(markdown), cx)
878 });
879 anyhow::Ok(())
880 })
881 .detach_and_log_err(cx);
882 buffer
883 });
884
885 let mut this = Self {
886 message_anchors: Default::default(),
887 messages_metadata: Default::default(),
888 next_message_id: Default::default(),
889 summary: None,
890 pending_summary: Task::ready(None),
891 completion_count: Default::default(),
892 pending_completions: Default::default(),
893 token_count: None,
894 max_token_count: tiktoken_rs::model::get_context_size(model),
895 pending_token_count: Task::ready(None),
896 model: model.into(),
897 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
898 pending_save: Task::ready(Ok(())),
899 path: None,
900 api_key,
901 buffer,
902 };
903 let message = MessageAnchor {
904 id: MessageId(post_inc(&mut this.next_message_id.0)),
905 start: language::Anchor::MIN,
906 };
907 this.message_anchors.push(message.clone());
908 this.messages_metadata.insert(
909 message.id,
910 MessageMetadata {
911 role: Role::User,
912 sent_at: Local::now(),
913 status: MessageStatus::Done,
914 },
915 );
916
917 this.count_remaining_tokens(cx);
918 this
919 }
920
921 fn serialize(&self, cx: &AppContext) -> SavedConversation {
922 SavedConversation {
923 zed: "conversation".into(),
924 version: SavedConversation::VERSION.into(),
925 text: self.buffer.read(cx).text(),
926 message_metadata: self.messages_metadata.clone(),
927 messages: self
928 .messages(cx)
929 .map(|message| SavedMessage {
930 id: message.id,
931 start: message.offset_range.start,
932 })
933 .collect(),
934 summary: self
935 .summary
936 .as_ref()
937 .map(|summary| summary.text.clone())
938 .unwrap_or_default(),
939 model: self.model.clone(),
940 }
941 }
942
943 fn deserialize(
944 saved_conversation: SavedConversation,
945 path: PathBuf,
946 api_key: Rc<RefCell<Option<String>>>,
947 language_registry: Arc<LanguageRegistry>,
948 cx: &mut ModelContext<Self>,
949 ) -> Self {
950 let model = saved_conversation.model;
951 let markdown = language_registry.language_for_name("Markdown");
952 let mut message_anchors = Vec::new();
953 let mut next_message_id = MessageId(0);
954 let buffer = cx.add_model(|cx| {
955 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
956 for message in saved_conversation.messages {
957 message_anchors.push(MessageAnchor {
958 id: message.id,
959 start: buffer.anchor_before(message.start),
960 });
961 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
962 }
963 buffer.set_language_registry(language_registry);
964 cx.spawn_weak(|buffer, mut cx| async move {
965 let markdown = markdown.await?;
966 let buffer = buffer
967 .upgrade(&cx)
968 .ok_or_else(|| anyhow!("buffer was dropped"))?;
969 buffer.update(&mut cx, |buffer, cx| {
970 buffer.set_language(Some(markdown), cx)
971 });
972 anyhow::Ok(())
973 })
974 .detach_and_log_err(cx);
975 buffer
976 });
977
978 let mut this = Self {
979 message_anchors,
980 messages_metadata: saved_conversation.message_metadata,
981 next_message_id,
982 summary: Some(Summary {
983 text: saved_conversation.summary,
984 done: true,
985 }),
986 pending_summary: Task::ready(None),
987 completion_count: Default::default(),
988 pending_completions: Default::default(),
989 token_count: None,
990 max_token_count: tiktoken_rs::model::get_context_size(&model),
991 pending_token_count: Task::ready(None),
992 model,
993 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
994 pending_save: Task::ready(Ok(())),
995 path: Some(path),
996 api_key,
997 buffer,
998 };
999 this.count_remaining_tokens(cx);
1000 this
1001 }
1002
1003 fn handle_buffer_event(
1004 &mut self,
1005 _: ModelHandle<Buffer>,
1006 event: &language::Event,
1007 cx: &mut ModelContext<Self>,
1008 ) {
1009 match event {
1010 language::Event::Edited => {
1011 self.count_remaining_tokens(cx);
1012 cx.emit(ConversationEvent::MessagesEdited);
1013 }
1014 _ => {}
1015 }
1016 }
1017
1018 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1019 let messages = self
1020 .messages(cx)
1021 .into_iter()
1022 .filter_map(|message| {
1023 Some(tiktoken_rs::ChatCompletionRequestMessage {
1024 role: match message.role {
1025 Role::User => "user".into(),
1026 Role::Assistant => "assistant".into(),
1027 Role::System => "system".into(),
1028 },
1029 content: self
1030 .buffer
1031 .read(cx)
1032 .text_for_range(message.offset_range)
1033 .collect(),
1034 name: None,
1035 })
1036 })
1037 .collect::<Vec<_>>();
1038 let model = self.model.clone();
1039 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1040 async move {
1041 cx.background().timer(Duration::from_millis(200)).await;
1042 let token_count = cx
1043 .background()
1044 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1045 .await?;
1046
1047 this.upgrade(&cx)
1048 .ok_or_else(|| anyhow!("conversation was dropped"))?
1049 .update(&mut cx, |this, cx| {
1050 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1051 this.token_count = Some(token_count);
1052 cx.notify()
1053 });
1054 anyhow::Ok(())
1055 }
1056 .log_err()
1057 });
1058 }
1059
1060 fn remaining_tokens(&self) -> Option<isize> {
1061 Some(self.max_token_count as isize - self.token_count? as isize)
1062 }
1063
/// Switches the conversation to `model` and schedules a token recount against it.
///
/// `max_token_count` is refreshed when that recount completes, not immediately.
fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
    self.model = model;
    // Must run after the assignment so the recount snapshots the new model name.
    self.count_remaining_tokens(cx);
    cx.notify();
}
1069
1070 fn assist(
1071 &mut self,
1072 selected_messages: HashSet<MessageId>,
1073 cx: &mut ModelContext<Self>,
1074 ) -> Vec<MessageAnchor> {
1075 let mut user_messages = Vec::new();
1076 let mut tasks = Vec::new();
1077
1078 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1079 message
1080 .start
1081 .is_valid(self.buffer.read(cx))
1082 .then_some(message.id)
1083 });
1084
1085 for selected_message_id in selected_messages {
1086 let selected_message_role =
1087 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1088 metadata.role
1089 } else {
1090 continue;
1091 };
1092
1093 if selected_message_role == Role::Assistant {
1094 if let Some(user_message) = self.insert_message_after(
1095 selected_message_id,
1096 Role::User,
1097 MessageStatus::Done,
1098 cx,
1099 ) {
1100 user_messages.push(user_message);
1101 } else {
1102 continue;
1103 }
1104 } else {
1105 let request = OpenAIRequest {
1106 model: self.model.clone(),
1107 messages: self
1108 .messages(cx)
1109 .filter(|message| matches!(message.status, MessageStatus::Done))
1110 .flat_map(|message| {
1111 let mut system_message = None;
1112 if message.id == selected_message_id {
1113 system_message = Some(RequestMessage {
1114 role: Role::System,
1115 content: concat!(
1116 "Treat the following messages as additional knowledge you have learned about, ",
1117 "but act as if they were not part of this conversation. That is, treat them ",
1118 "as if the user didn't see them and couldn't possibly inquire about them."
1119 ).into()
1120 });
1121 }
1122
1123 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1124 })
1125 .chain(Some(RequestMessage {
1126 role: Role::System,
1127 content: format!(
1128 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1129 selected_message_id.0
1130 ),
1131 }))
1132 .collect(),
1133 stream: true,
1134 };
1135
1136 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1137 let stream = stream_completion(api_key, cx.background().clone(), request);
1138 let assistant_message = self
1139 .insert_message_after(
1140 selected_message_id,
1141 Role::Assistant,
1142 MessageStatus::Pending,
1143 cx,
1144 )
1145 .unwrap();
1146
1147 // Queue up the user's next reply
1148 if Some(selected_message_id) == last_message_id {
1149 let user_message = self
1150 .insert_message_after(
1151 assistant_message.id,
1152 Role::User,
1153 MessageStatus::Done,
1154 cx,
1155 )
1156 .unwrap();
1157 user_messages.push(user_message);
1158 }
1159
1160 tasks.push(cx.spawn_weak({
1161 |this, mut cx| async move {
1162 let assistant_message_id = assistant_message.id;
1163 let stream_completion = async {
1164 let mut messages = stream.await?;
1165
1166 while let Some(message) = messages.next().await {
1167 let mut message = message?;
1168 if let Some(choice) = message.choices.pop() {
1169 this.upgrade(&cx)
1170 .ok_or_else(|| anyhow!("conversation was dropped"))?
1171 .update(&mut cx, |this, cx| {
1172 let text: Arc<str> = choice.delta.content?.into();
1173 let message_ix = this.message_anchors.iter().position(
1174 |message| message.id == assistant_message_id,
1175 )?;
1176 this.buffer.update(cx, |buffer, cx| {
1177 let offset = this.message_anchors[message_ix + 1..]
1178 .iter()
1179 .find(|message| message.start.is_valid(buffer))
1180 .map_or(buffer.len(), |message| {
1181 message
1182 .start
1183 .to_offset(buffer)
1184 .saturating_sub(1)
1185 });
1186 buffer.edit([(offset..offset, text)], None, cx);
1187 });
1188 cx.emit(ConversationEvent::StreamedCompletion);
1189
1190 Some(())
1191 });
1192 }
1193 smol::future::yield_now().await;
1194 }
1195
1196 this.upgrade(&cx)
1197 .ok_or_else(|| anyhow!("conversation was dropped"))?
1198 .update(&mut cx, |this, cx| {
1199 this.pending_completions.retain(|completion| {
1200 completion.id != this.completion_count
1201 });
1202 this.summarize(cx);
1203 });
1204
1205 anyhow::Ok(())
1206 };
1207
1208 let result = stream_completion.await;
1209 if let Some(this) = this.upgrade(&cx) {
1210 this.update(&mut cx, |this, cx| {
1211 if let Some(metadata) =
1212 this.messages_metadata.get_mut(&assistant_message.id)
1213 {
1214 match result {
1215 Ok(_) => {
1216 metadata.status = MessageStatus::Done;
1217 }
1218 Err(error) => {
1219 metadata.status = MessageStatus::Error(
1220 error.to_string().trim().into(),
1221 );
1222 }
1223 }
1224 cx.notify();
1225 }
1226 });
1227 }
1228 }
1229 }));
1230 }
1231 }
1232
1233 if !tasks.is_empty() {
1234 self.pending_completions.push(PendingCompletion {
1235 id: post_inc(&mut self.completion_count),
1236 _tasks: tasks,
1237 });
1238 }
1239
1240 user_messages
1241 }
1242
1243 fn cancel_last_assist(&mut self) -> bool {
1244 self.pending_completions.pop().is_some()
1245 }
1246
1247 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1248 for id in ids {
1249 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1250 metadata.role.cycle();
1251 cx.emit(ConversationEvent::MessagesEdited);
1252 cx.notify();
1253 }
1254 }
1255 }
1256
1257 fn insert_message_after(
1258 &mut self,
1259 message_id: MessageId,
1260 role: Role,
1261 status: MessageStatus,
1262 cx: &mut ModelContext<Self>,
1263 ) -> Option<MessageAnchor> {
1264 if let Some(prev_message_ix) = self
1265 .message_anchors
1266 .iter()
1267 .position(|message| message.id == message_id)
1268 {
1269 // Find the next valid message after the one we were given.
1270 let mut next_message_ix = prev_message_ix + 1;
1271 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1272 if next_message.start.is_valid(self.buffer.read(cx)) {
1273 break;
1274 }
1275 next_message_ix += 1;
1276 }
1277
1278 let start = self.buffer.update(cx, |buffer, cx| {
1279 let offset = self
1280 .message_anchors
1281 .get(next_message_ix)
1282 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1283 buffer.edit([(offset..offset, "\n")], None, cx);
1284 buffer.anchor_before(offset + 1)
1285 });
1286 let message = MessageAnchor {
1287 id: MessageId(post_inc(&mut self.next_message_id.0)),
1288 start,
1289 };
1290 self.message_anchors
1291 .insert(next_message_ix, message.clone());
1292 self.messages_metadata.insert(
1293 message.id,
1294 MessageMetadata {
1295 role,
1296 sent_at: Local::now(),
1297 status,
1298 },
1299 );
1300 cx.emit(ConversationEvent::MessagesEdited);
1301 Some(message)
1302 } else {
1303 None
1304 }
1305 }
1306
1307 fn split_message(
1308 &mut self,
1309 range: Range<usize>,
1310 cx: &mut ModelContext<Self>,
1311 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1312 let start_message = self.message_for_offset(range.start, cx);
1313 let end_message = self.message_for_offset(range.end, cx);
1314 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1315 // Prevent splitting when range spans multiple messages.
1316 if start_message.id != end_message.id {
1317 return (None, None);
1318 }
1319
1320 let message = start_message;
1321 let role = message.role;
1322 let mut edited_buffer = false;
1323
1324 let mut suffix_start = None;
1325 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1326 {
1327 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1328 suffix_start = Some(range.end + 1);
1329 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1330 suffix_start = Some(range.end);
1331 }
1332 }
1333
1334 let suffix = if let Some(suffix_start) = suffix_start {
1335 MessageAnchor {
1336 id: MessageId(post_inc(&mut self.next_message_id.0)),
1337 start: self.buffer.read(cx).anchor_before(suffix_start),
1338 }
1339 } else {
1340 self.buffer.update(cx, |buffer, cx| {
1341 buffer.edit([(range.end..range.end, "\n")], None, cx);
1342 });
1343 edited_buffer = true;
1344 MessageAnchor {
1345 id: MessageId(post_inc(&mut self.next_message_id.0)),
1346 start: self.buffer.read(cx).anchor_before(range.end + 1),
1347 }
1348 };
1349
1350 self.message_anchors
1351 .insert(message.index_range.end + 1, suffix.clone());
1352 self.messages_metadata.insert(
1353 suffix.id,
1354 MessageMetadata {
1355 role,
1356 sent_at: Local::now(),
1357 status: MessageStatus::Done,
1358 },
1359 );
1360
1361 let new_messages =
1362 if range.start == range.end || range.start == message.offset_range.start {
1363 (None, Some(suffix))
1364 } else {
1365 let mut prefix_end = None;
1366 if range.start > message.offset_range.start
1367 && range.end < message.offset_range.end - 1
1368 {
1369 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1370 prefix_end = Some(range.start + 1);
1371 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1372 == Some('\n')
1373 {
1374 prefix_end = Some(range.start);
1375 }
1376 }
1377
1378 let selection = if let Some(prefix_end) = prefix_end {
1379 cx.emit(ConversationEvent::MessagesEdited);
1380 MessageAnchor {
1381 id: MessageId(post_inc(&mut self.next_message_id.0)),
1382 start: self.buffer.read(cx).anchor_before(prefix_end),
1383 }
1384 } else {
1385 self.buffer.update(cx, |buffer, cx| {
1386 buffer.edit([(range.start..range.start, "\n")], None, cx)
1387 });
1388 edited_buffer = true;
1389 MessageAnchor {
1390 id: MessageId(post_inc(&mut self.next_message_id.0)),
1391 start: self.buffer.read(cx).anchor_before(range.end + 1),
1392 }
1393 };
1394
1395 self.message_anchors
1396 .insert(message.index_range.end + 1, selection.clone());
1397 self.messages_metadata.insert(
1398 selection.id,
1399 MessageMetadata {
1400 role,
1401 sent_at: Local::now(),
1402 status: MessageStatus::Done,
1403 },
1404 );
1405 (Some(selection), Some(suffix))
1406 };
1407
1408 if !edited_buffer {
1409 cx.emit(ConversationEvent::MessagesEdited);
1410 }
1411 new_messages
1412 } else {
1413 (None, None)
1414 }
1415 }
1416
1417 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1418 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1419 let api_key = self.api_key.borrow().clone();
1420 if let Some(api_key) = api_key {
1421 let messages = self
1422 .messages(cx)
1423 .take(2)
1424 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1425 .chain(Some(RequestMessage {
1426 role: Role::User,
1427 content:
1428 "Summarize the conversation into a short title without punctuation"
1429 .into(),
1430 }));
1431 let request = OpenAIRequest {
1432 model: self.model.clone(),
1433 messages: messages.collect(),
1434 stream: true,
1435 };
1436
1437 let stream = stream_completion(api_key, cx.background().clone(), request);
1438 self.pending_summary = cx.spawn(|this, mut cx| {
1439 async move {
1440 let mut messages = stream.await?;
1441
1442 while let Some(message) = messages.next().await {
1443 let mut message = message?;
1444 if let Some(choice) = message.choices.pop() {
1445 let text = choice.delta.content.unwrap_or_default();
1446 this.update(&mut cx, |this, cx| {
1447 this.summary
1448 .get_or_insert(Default::default())
1449 .text
1450 .push_str(&text);
1451 cx.emit(ConversationEvent::SummaryChanged);
1452 });
1453 }
1454 }
1455
1456 this.update(&mut cx, |this, cx| {
1457 if let Some(summary) = this.summary.as_mut() {
1458 summary.done = true;
1459 cx.emit(ConversationEvent::SummaryChanged);
1460 }
1461 });
1462
1463 anyhow::Ok(())
1464 }
1465 .log_err()
1466 });
1467 }
1468 }
1469 }
1470
1471 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1472 self.messages_for_offsets([offset], cx).pop()
1473 }
1474
    /// Maps buffer offsets to the messages that contain them, without duplicates.
    ///
    /// The scan over messages only moves forward, so `offsets` are expected in
    /// ascending order (callers pass cursor positions). This is not checked. An
    /// offset beyond the last message (e.g. the buffer end) resolves to the last
    /// message. Each message appears at most once in the result, even when
    /// several offsets fall inside it. Returns an empty vector when there are
    /// no messages.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Stops at the last message even if it doesn't contain the offset.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // Once on the last message, every remaining offset maps to it, so skip them all.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1505
    /// Iterates over the conversation's messages in buffer order.
    ///
    /// Each yielded `Message` spans from its anchor to the start of the next
    /// anchor whose start is still valid in the buffer, or to the buffer end.
    /// Anchors invalidated in between (their text was deleted) are folded into
    /// the preceding message. `index_range.end` is therefore the index of the
    /// last anchor absorbed, and callers insert new anchors at
    /// `index_range.end + 1`.
    ///
    /// The anchor being yielded is not itself validity-checked; only the
    /// anchors that follow it are. NOTE(review): this presumably relies on the
    /// first anchor always staying valid. Also, if an anchor has no entry in
    /// `messages_metadata`, the `?` below ends the whole iteration rather than
    /// skipping that one message.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            // Always returns on its first iteration, so this acts like `if let`.
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume invalid anchors until the next valid one, which ends this message.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                // No following valid anchor: the message runs to the end of the buffer.
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1540
1541 fn save(
1542 &mut self,
1543 debounce: Option<Duration>,
1544 fs: Arc<dyn Fs>,
1545 cx: &mut ModelContext<Conversation>,
1546 ) {
1547 self.pending_save = cx.spawn(|this, mut cx| async move {
1548 if let Some(debounce) = debounce {
1549 cx.background().timer(debounce).await;
1550 }
1551
1552 let (old_path, summary) = this.read_with(&cx, |this, _| {
1553 let path = this.path.clone();
1554 let summary = if let Some(summary) = this.summary.as_ref() {
1555 if summary.done {
1556 Some(summary.text.clone())
1557 } else {
1558 None
1559 }
1560 } else {
1561 None
1562 };
1563 (path, summary)
1564 });
1565
1566 if let Some(summary) = summary {
1567 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1568 let path = if let Some(old_path) = old_path {
1569 old_path
1570 } else {
1571 let mut discriminant = 1;
1572 let mut new_path;
1573 loop {
1574 new_path = CONVERSATIONS_DIR.join(&format!(
1575 "{} - {}.zed.json",
1576 summary.trim(),
1577 discriminant
1578 ));
1579 if fs.is_file(&new_path).await {
1580 discriminant += 1;
1581 } else {
1582 break;
1583 }
1584 }
1585 new_path
1586 };
1587
1588 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1589 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1590 .await?;
1591 this.update(&mut cx, |this, _| this.path = Some(path));
1592 }
1593
1594 Ok(())
1595 });
1596 }
1597}
1598
/// A group of completion tasks started by a single `Conversation::assist` call.
struct PendingCompletion {
    /// Value taken from `Conversation::completion_count` when the group was created.
    id: usize,
    /// Handles of the streaming tasks in this group. Never read; they are held
    /// so the tasks live exactly as long as the group (dropping the group drops them).
    _tasks: Vec<Task<()>>,
}
1603
/// Events emitted by `ConversationEditor`.
enum ConversationEditorEvent {
    /// The tab's content (title) may have changed; emitted when the
    /// conversation's summary changes.
    TabContentChanged,
}
1607
/// The newest cursor's place in the viewport.
///
/// Used to keep the cursor visually stationary while completion text streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// x: horizontal scroll position; y: the cursor's distance in display rows
    /// from the top of the viewport.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1613
/// A view that edits one `Conversation` through an embedded `Editor`.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem handed to `Conversation::save`.
    fs: Arc<dyn Fs>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor placement to restore while a completion streams in. `None` when the
    /// cursor is off-screen or the user moved the viewport manually.
    scroll_position: Option<ScrollPosition>,
    /// Held (never read) so the subscriptions created in `for_conversation` are
    /// not dropped.
    _subscriptions: Vec<Subscription>,
}
1622
impl ConversationEditor {
    /// Creates an editor for a brand-new, empty conversation.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Creates an editor for an existing conversation model.
    ///
    /// The inner editor soft-wraps at its width and hides the gutter and wrap
    /// guides. The view re-renders on conversation notifications, handles
    /// conversation and editor events, and inserts the initial message headers.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor.set_show_wrap_guides(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles `Assist`: requests completions for the messages under the cursors,
    /// then moves the cursors into any user messages queued for the next reply.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
            // Avoid scrolling to the new cursor position so the assistant's output is stable.
            cx.defer(|this, _| this.scroll_position = None);
        }
    }

    /// Handles `editor::Cancel`: cancels the latest pending completion group, or
    /// lets the action propagate when nothing is pending.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Handles `CycleMessageRole` for every message under a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Buffer offsets of every selection head in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to conversation events.
    ///
    /// - Message edits refresh headers and schedule a save 500ms out.
    /// - Summary changes retitle the tab and save immediately.
    /// - Streamed text restores the recorded cursor position in the viewport.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    // Scroll so the cursor keeps its row offset from the viewport top.
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks the cursor's place in the viewport.
    ///
    /// Autoscrolls and selection changes record it. A manual scroll that changes
    /// it clears the record, so streaming no longer drags the view.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// The newest cursor's position relative to the viewport.
    ///
    /// Returns `None` when its display row is outside the visible rows. That
    /// includes the case where the visible line count is unknown, since the
    /// range is then empty.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces all message-header blocks with one sticky, two-row block above
    /// each message.
    ///
    /// Each header shows the sender role (clicking it cycles the role), the time
    /// the message was sent, and, for `Error` messages, an icon whose tooltip
    /// holds the error text.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor was created with `Editor::for_buffer`, so there is a single excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        // `message` is moved into the render closure; each render reads its
                        // snapshot of role, timestamp and status.
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Workspace action: inserts the active editor's newest selection into the
    /// assistant panel's active conversation, creating one if needed.
    ///
    /// Markdown selections are quoted line by line with `> `. Everything else is
    /// wrapped in a fenced code block tagged with the lowercased language name,
    /// or left untagged when the selection's start and end languages differ.
    /// The panel is focused even when the selection is empty.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy`.
    ///
    /// With exactly one selection overlapping more than one message, the
    /// selected part of each message is copied under a `## <role>` heading.
    /// Otherwise the action propagates to the editor's regular copy.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    // Portion of this message covered by the selection.
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Handles `Split`: splits the message under each selection at the
    /// selection's boundaries.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                // Re-snapshot each time: earlier splits may have inserted newlines.
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Handles `Save`: saves the conversation immediately (no debounce).
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles between `gpt-4-0613` and `gpt-3.5-turbo-0613`. Any model other
    /// than `gpt-4-0613` switches to `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// The conversation's summary text so far, or `"New Conversation"` if none.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Renders the current model name; clicking it calls `cycle_model`.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Renders the remaining-token count, if known.
    ///
    /// The style depends on the count: `no_remaining_tokens` at or below 0,
    /// `low_remaining_tokens` at or below 500, `remaining_tokens` otherwise.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else if remaining_tokens <= 500 {
            &style.low_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2093
/// `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2097
impl View for ConversationEditor {
    fn ui_name() -> &'static str {
        "ConversationEditor"
    }

    /// Renders the editor, with the model selector and the remaining-token
    /// count stacked on top of it in the top-right corner.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        let theme = &theme::current(cx).assistant;
        Stack::new()
            .with_child(
                ChildView::new(&self.editor, cx)
                    .contained()
                    .with_style(theme.container),
            )
            .with_child(
                Flex::row()
                    .with_child(self.render_current_model(theme, cx))
                    .with_children(self.render_remaining_tokens(theme, cx))
                    .aligned()
                    .top()
                    .right(),
            )
            .into_any()
    }

    /// Forwards focus to the inner editor when this view itself gains focus.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        if cx.is_self_focused() {
            cx.focus(&self.editor);
        }
    }
}
2128
/// Where a message begins in the conversation buffer.
///
/// The anchor becomes invalid when the text it points into is deleted; such
/// anchors are folded into the preceding message by `Conversation::messages`.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
2134
/// A resolved snapshot of one message, as produced by `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets from this message's start up to the next valid message's
    /// start, or the buffer end.
    offset_range: Range<usize>,
    /// Indices into `Conversation::message_anchors`. `end` is the index of the
    /// last (invalidated) anchor folded into this message, not one past it.
    index_range: Range<usize>,
    id: MessageId,
    /// The anchor at which this message starts.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2145
2146impl Message {
2147 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2148 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2149 content.extend(buffer.text_for_range(self.offset_range.clone()));
2150 RequestMessage {
2151 role: self.role,
2152 content: content.trim_end().into(),
2153 }
2154 }
2155}
2156
2157async fn stream_completion(
2158 api_key: String,
2159 executor: Arc<Background>,
2160 mut request: OpenAIRequest,
2161) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2162 request.stream = true;
2163
2164 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2165
2166 let json_data = serde_json::to_string(&request)?;
2167 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2168 .header("Content-Type", "application/json")
2169 .header("Authorization", format!("Bearer {}", api_key))
2170 .body(json_data)?
2171 .send_async()
2172 .await?;
2173
2174 let status = response.status();
2175 if status == StatusCode::OK {
2176 executor
2177 .spawn(async move {
2178 let mut lines = BufReader::new(response.body_mut()).lines();
2179
2180 fn parse_line(
2181 line: Result<String, io::Error>,
2182 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2183 if let Some(data) = line?.strip_prefix("data: ") {
2184 let event = serde_json::from_str(&data)?;
2185 Ok(Some(event))
2186 } else {
2187 Ok(None)
2188 }
2189 }
2190
2191 while let Some(line) = lines.next().await {
2192 if let Some(event) = parse_line(line).transpose() {
2193 let done = event.as_ref().map_or(false, |event| {
2194 event
2195 .choices
2196 .last()
2197 .map_or(false, |choice| choice.finish_reason.is_some())
2198 });
2199 if tx.unbounded_send(event).is_err() {
2200 break;
2201 }
2202
2203 if done {
2204 break;
2205 }
2206 }
2207 }
2208
2209 anyhow::Ok(())
2210 })
2211 .detach();
2212
2213 Ok(rx)
2214 } else {
2215 let mut body = String::new();
2216 response.body_mut().read_to_string(&mut body).await?;
2217
2218 #[derive(Deserialize)]
2219 struct OpenAIResponse {
2220 error: OpenAIError,
2221 }
2222
2223 #[derive(Deserialize)]
2224 struct OpenAIError {
2225 message: String,
2226 }
2227
2228 match serde_json::from_str::<OpenAIResponse>(&body) {
2229 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2230 "Failed to connect to OpenAI API: {}",
2231 response.error.message,
2232 )),
2233
2234 _ => Err(anyhow!(
2235 "Failed to connect to OpenAI API: {} {}",
2236 response.status(),
2237 body,
2238 )),
2239 }
2240 }
2241}
2242
// Unit tests for `Conversation`'s message bookkeeping: inserting and merging
// messages, splitting messages, mapping offsets to messages, and
// serialization round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    // Checks that message ranges follow buffer edits, that `insert_message_after`
    // places new messages correctly, and that deletions spanning message
    // boundaries merge messages in a way undo/redo can reverse.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 adds one separating character, so message_1
        // grows to 0..1 and the new empty message starts at offset 1.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Typing into each message grows that message's own range.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between
        // message_2 and the previously inserted message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // Exercises `split_message`, both at a cursor (empty range) and over a
    // non-empty selection, and checks the resulting buffer text and ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting over a non-empty selection creates one message for the
        // selected text and another for the text that follows it.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // Checks that `messages_for_offsets` maps buffer offsets to the messages
    // containing them, reporting each message only once.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 fall in the same message, which is reported once; the
        // offset at the very end of the buffer belongs to the last message.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages that `messages_for_offsets` reports
        // for the given offsets.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    // Round-trips a conversation through `serialize`/`deserialize` and checks
    // that the buffer text and message ranges survive unchanged.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // The edit above was finalized as its own transaction, so this undo
        // reverts only message_3's insertion; message_3 must not be serialized.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    // Snapshot of every message as (id, role, offset range), so assertions
    // can compare the whole conversation layout at once.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}