use crate::{
    assistant_settings::{AssistantDockPosition, AssistantSettings},
    MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
    RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
};
use anyhow::{anyhow, Context as _, Result};
use chrono::{DateTime, Local};
use collections::{HashMap, HashSet};
use editor::{
    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
    Anchor, Editor, ToOffset,
};
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use gpui::{
    actions,
    elements::*,
    executor::Background,
    geometry::vector::{vec2f, Vector2F},
    platform::{CursorStyle, MouseButton},
    Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
};
use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
use serde::Deserialize;
use settings::SettingsStore;
use std::{
    cell::RefCell,
    cmp, env,
    fmt::Write,
    io, iter,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::Duration,
};
use theme::AssistantStyle;
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
use workspace::{
    dock::{DockPosition, Panel},
    searchable::Direction,
    Save, ToggleZoom, Toolbar, Workspace,
};
48
/// Base URL of the OpenAI REST API.
///
/// It also serves as the key under which the API key is stored in, read from, and
/// deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions exposed by the assistant, namespaced under `assistant`. They are registered
// in `init` and also referenced by the panel's buttons and tooltips.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`], interpreted by its `Panel` implementation.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Request that the panel receive focus.
    Focus,
    /// Request that the panel be closed.
    Close,
    /// Request that the panel be zoomed in.
    ZoomIn,
    /// Request that the panel be zoomed out.
    ZoomOut,
    /// The dock position from the settings changed and the panel should move.
    DockPositionChanged,
}
102
/// The assistant dock panel.
///
/// It hosts the open conversation editors, the list of saved conversations, and the
/// OpenAI API-key prompt.
pub struct AssistantPanel {
    /// Workspace that owns the panel, held weakly.
    workspace: WeakViewHandle<Workspace>,
    /// Width set by the user while docked left/right. `None` falls back to the
    /// `default_width` setting.
    width: Option<f32>,
    /// Height set by the user while docked at the bottom. `None` falls back to the
    /// `default_height` setting.
    height: Option<f32>,
    /// Index into `editors` of the conversation being shown. `None` shows the list of
    /// saved conversations instead.
    active_editor_index: Option<usize>,
    /// Previous value of `active_editor_index`, restored by the history (hamburger)
    /// button.
    prev_active_editor_index: Option<usize>,
    /// Conversation editors opened in this panel, in the order they were added.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations saved on disk. It is re-listed whenever the
    /// conversations directory changes.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// Scroll/list state of the saved-conversations list.
    saved_conversations_list_state: UniformListState,
    /// Whether the panel is currently zoomed.
    zoomed: bool,
    /// Focus flag maintained by `focus_in` / `focus_out`.
    has_focus: bool,
    /// Toolbar hosting the buffer search bar for the active conversation.
    toolbar: ViewHandle<Toolbar>,
    /// OpenAI API key shared (via `Rc`) with every conversation the panel creates or
    /// loads. `None` until a key is found in the environment or keychain, or is
    /// entered by the user.
    api_key: Rc<RefCell<Option<String>>>,
    /// Editor prompting for an API key. `Some` only while that prompt is shown.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the environment/keychain have been checked for a key, so the lookup
    /// happens at most once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to settings changes and to each editor / conversation.
    subscriptions: Vec<Subscription>,
    /// Task watching the conversations directory. It is stored here so the panel owns
    /// it.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 })];
196
197 this
198 })
199 })
200 })
201 }
202
203 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
204 let editor = cx.add_view(|cx| {
205 ConversationEditor::new(
206 self.api_key.clone(),
207 self.languages.clone(),
208 self.fs.clone(),
209 cx,
210 )
211 });
212 self.add_conversation(editor.clone(), cx);
213 editor
214 }
215
216 fn add_conversation(
217 &mut self,
218 editor: ViewHandle<ConversationEditor>,
219 cx: &mut ViewContext<Self>,
220 ) {
221 self.subscriptions
222 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
223
224 let conversation = editor.read(cx).conversation.clone();
225 self.subscriptions
226 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
227
228 let index = self.editors.len();
229 self.editors.push(editor);
230 self.set_active_editor_index(Some(index), cx);
231 }
232
233 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
234 self.prev_active_editor_index = self.active_editor_index;
235 self.active_editor_index = index;
236 if let Some(editor) = self.active_editor() {
237 let editor = editor.read(cx).editor.clone();
238 self.toolbar.update(cx, |toolbar, cx| {
239 toolbar.set_active_item(Some(&editor), cx);
240 });
241 if self.has_focus(cx) {
242 cx.focus(&editor);
243 }
244 } else {
245 self.toolbar.update(cx, |toolbar, cx| {
246 toolbar.set_active_item(None, cx);
247 });
248 }
249
250 cx.notify();
251 }
252
253 fn handle_conversation_editor_event(
254 &mut self,
255 _: ViewHandle<ConversationEditor>,
256 event: &ConversationEditorEvent,
257 cx: &mut ViewContext<Self>,
258 ) {
259 match event {
260 ConversationEditorEvent::TabContentChanged => cx.notify(),
261 }
262 }
263
264 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
265 if let Some(api_key) = self
266 .api_key_editor
267 .as_ref()
268 .map(|editor| editor.read(cx).text(cx))
269 {
270 if !api_key.is_empty() {
271 cx.platform()
272 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
273 .log_err();
274 *self.api_key.borrow_mut() = Some(api_key);
275 self.api_key_editor.take();
276 cx.focus_self();
277 cx.notify();
278 }
279 } else {
280 cx.propagate_action();
281 }
282 }
283
284 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
285 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
286 self.api_key.take();
287 self.api_key_editor = Some(build_api_key_editor(cx));
288 cx.focus_self();
289 cx.notify();
290 }
291
292 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
293 if self.zoomed {
294 cx.emit(AssistantPanelEvent::ZoomOut)
295 } else {
296 cx.emit(AssistantPanelEvent::ZoomIn)
297 }
298 }
299
300 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
301 let mut propagate_action = true;
302 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
303 search_bar.update(cx, |search_bar, cx| {
304 if search_bar.show(cx) {
305 search_bar.search_suggested(cx);
306 if action.focus {
307 search_bar.select_query(cx);
308 cx.focus_self();
309 }
310 propagate_action = false
311 }
312 });
313 }
314 if propagate_action {
315 cx.propagate_action();
316 }
317 }
318
319 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
320 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
321 if !search_bar.read(cx).is_dismissed() {
322 search_bar.update(cx, |search_bar, cx| {
323 search_bar.dismiss(&Default::default(), cx)
324 });
325 return;
326 }
327 }
328 cx.propagate_action();
329 }
330
331 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
332 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
333 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
334 }
335 }
336
337 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
338 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
339 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
340 }
341 }
342
343 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
344 self.editors.get(self.active_editor_index?)
345 }
346
347 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
348 enum History {}
349 let theme = theme::current(cx);
350 let tooltip_style = theme::current(cx).tooltip.clone();
351 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
352 let style = theme.assistant.hamburger_button.style_for(state);
353 Svg::for_style(style.icon.clone())
354 .contained()
355 .with_style(style.container)
356 })
357 .with_cursor_style(CursorStyle::PointingHand)
358 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
359 if this.active_editor().is_some() {
360 this.set_active_editor_index(None, cx);
361 } else {
362 this.set_active_editor_index(this.prev_active_editor_index, cx);
363 }
364 })
365 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
366 }
367
368 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
369 if self.active_editor().is_some() {
370 vec![
371 Self::render_split_button(cx).into_any(),
372 Self::render_quote_button(cx).into_any(),
373 Self::render_assist_button(cx).into_any(),
374 ]
375 } else {
376 Default::default()
377 }
378 }
379
380 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
381 let theme = theme::current(cx);
382 let tooltip_style = theme::current(cx).tooltip.clone();
383 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
384 let style = theme.assistant.split_button.style_for(state);
385 Svg::for_style(style.icon.clone())
386 .contained()
387 .with_style(style.container)
388 })
389 .with_cursor_style(CursorStyle::PointingHand)
390 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
391 if let Some(active_editor) = this.active_editor() {
392 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
393 }
394 })
395 .with_tooltip::<Split>(
396 1,
397 "Split Message".into(),
398 Some(Box::new(Split)),
399 tooltip_style,
400 cx,
401 )
402 }
403
404 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
405 let theme = theme::current(cx);
406 let tooltip_style = theme::current(cx).tooltip.clone();
407 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
408 let style = theme.assistant.assist_button.style_for(state);
409 Svg::for_style(style.icon.clone())
410 .contained()
411 .with_style(style.container)
412 })
413 .with_cursor_style(CursorStyle::PointingHand)
414 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
415 if let Some(active_editor) = this.active_editor() {
416 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
417 }
418 })
419 .with_tooltip::<Assist>(
420 1,
421 "Assist".into(),
422 Some(Box::new(Assist)),
423 tooltip_style,
424 cx,
425 )
426 }
427
428 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
429 let theme = theme::current(cx);
430 let tooltip_style = theme::current(cx).tooltip.clone();
431 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
432 let style = theme.assistant.quote_button.style_for(state);
433 Svg::for_style(style.icon.clone())
434 .contained()
435 .with_style(style.container)
436 })
437 .with_cursor_style(CursorStyle::PointingHand)
438 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
439 if let Some(workspace) = this.workspace.upgrade(cx) {
440 cx.window_context().defer(move |cx| {
441 workspace.update(cx, |workspace, cx| {
442 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
443 });
444 });
445 }
446 })
447 .with_tooltip::<QuoteSelection>(
448 1,
449 "Quote Selection".into(),
450 Some(Box::new(QuoteSelection)),
451 tooltip_style,
452 cx,
453 )
454 }
455
456 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
457 let theme = theme::current(cx);
458 let tooltip_style = theme::current(cx).tooltip.clone();
459 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
460 let style = theme.assistant.plus_button.style_for(state);
461 Svg::for_style(style.icon.clone())
462 .contained()
463 .with_style(style.container)
464 })
465 .with_cursor_style(CursorStyle::PointingHand)
466 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
467 this.new_conversation(cx);
468 })
469 .with_tooltip::<NewConversation>(
470 1,
471 "New Conversation".into(),
472 Some(Box::new(NewConversation)),
473 tooltip_style,
474 cx,
475 )
476 }
477
478 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
479 enum ToggleZoomButton {}
480
481 let theme = theme::current(cx);
482 let tooltip_style = theme::current(cx).tooltip.clone();
483 let style = if self.zoomed {
484 &theme.assistant.zoom_out_button
485 } else {
486 &theme.assistant.zoom_in_button
487 };
488
489 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
490 let style = style.style_for(state);
491 Svg::for_style(style.icon.clone())
492 .contained()
493 .with_style(style.container)
494 })
495 .with_cursor_style(CursorStyle::PointingHand)
496 .on_click(MouseButton::Left, |_, this, cx| {
497 this.toggle_zoom(&ToggleZoom, cx);
498 })
499 .with_tooltip::<ToggleZoom>(
500 0,
501 if self.zoomed {
502 "Zoom Out".into()
503 } else {
504 "Zoom In".into()
505 },
506 Some(Box::new(ToggleZoom)),
507 tooltip_style,
508 cx,
509 )
510 }
511
512 fn render_saved_conversation(
513 &mut self,
514 index: usize,
515 cx: &mut ViewContext<Self>,
516 ) -> impl Element<Self> {
517 let conversation = &self.saved_conversations[index];
518 let path = conversation.path.clone();
519 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
520 let style = &theme::current(cx).assistant.saved_conversation;
521 Flex::row()
522 .with_child(
523 Label::new(
524 conversation.mtime.format("%F %I:%M%p").to_string(),
525 style.saved_at.text.clone(),
526 )
527 .aligned()
528 .contained()
529 .with_style(style.saved_at.container),
530 )
531 .with_child(
532 Label::new(conversation.title.clone(), style.title.text.clone())
533 .aligned()
534 .contained()
535 .with_style(style.title.container),
536 )
537 .contained()
538 .with_style(*style.container.style_for(state))
539 })
540 .with_cursor_style(CursorStyle::PointingHand)
541 .on_click(MouseButton::Left, move |_, this, cx| {
542 this.open_conversation(path.clone(), cx)
543 .detach_and_log_err(cx)
544 })
545 }
546
547 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
548 if let Some(ix) = self.editor_index_for_path(&path, cx) {
549 self.set_active_editor_index(Some(ix), cx);
550 return Task::ready(Ok(()));
551 }
552
553 let fs = self.fs.clone();
554 let api_key = self.api_key.clone();
555 let languages = self.languages.clone();
556 cx.spawn(|this, mut cx| async move {
557 let saved_conversation = fs.load(&path).await?;
558 let saved_conversation = serde_json::from_str(&saved_conversation)?;
559 let conversation = cx.add_model(|cx| {
560 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
561 });
562 this.update(&mut cx, |this, cx| {
563 // If, by the time we've loaded the conversation, the user has already opened
564 // the same conversation, we don't want to open it again.
565 if let Some(ix) = this.editor_index_for_path(&path, cx) {
566 this.set_active_editor_index(Some(ix), cx);
567 } else {
568 let editor = cx
569 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
570 this.add_conversation(editor, cx);
571 }
572 })?;
573 Ok(())
574 })
575 }
576
577 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
578 self.editors
579 .iter()
580 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
581 }
582}
583
584fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
585 cx.add_view(|cx| {
586 let mut editor = Editor::single_line(
587 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
588 cx,
589 );
590 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
591 editor
592 })
593}
594
/// The panel emits [`AssistantPanelEvent`]s. Its `Panel` implementation maps these to
/// dock behavior (zoom, focus, close, dock-position changes).
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
598
599impl View for AssistantPanel {
600 fn ui_name() -> &'static str {
601 "AssistantPanel"
602 }
603
604 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
605 let theme = &theme::current(cx);
606 let style = &theme.assistant;
607 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
608 Flex::column()
609 .with_child(
610 Text::new(
611 "Paste your OpenAI API key and press Enter to use the assistant",
612 style.api_key_prompt.text.clone(),
613 )
614 .aligned(),
615 )
616 .with_child(
617 ChildView::new(api_key_editor, cx)
618 .contained()
619 .with_style(style.api_key_editor.container)
620 .aligned(),
621 )
622 .contained()
623 .with_style(style.api_key_prompt.container)
624 .aligned()
625 .into_any()
626 } else {
627 let title = self.active_editor().map(|editor| {
628 Label::new(editor.read(cx).title(cx), style.title.text.clone())
629 .contained()
630 .with_style(style.title.container)
631 .aligned()
632 .left()
633 .flex(1., false)
634 });
635 let mut header = Flex::row()
636 .with_child(Self::render_hamburger_button(cx).aligned())
637 .with_children(title);
638 if self.has_focus {
639 header.add_children(
640 self.render_editor_tools(cx)
641 .into_iter()
642 .map(|tool| tool.aligned().flex_float()),
643 );
644 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
645 header.add_child(self.render_zoom_button(cx).aligned());
646 }
647
648 Flex::column()
649 .with_child(
650 header
651 .contained()
652 .with_style(theme.workspace.tab_bar.container)
653 .expanded()
654 .constrained()
655 .with_height(theme.workspace.tab_bar.height),
656 )
657 .with_children(if self.toolbar.read(cx).hidden() {
658 None
659 } else {
660 Some(ChildView::new(&self.toolbar, cx).expanded())
661 })
662 .with_child(if let Some(editor) = self.active_editor() {
663 ChildView::new(editor, cx).flex(1., true).into_any()
664 } else {
665 UniformList::new(
666 self.saved_conversations_list_state.clone(),
667 self.saved_conversations.len(),
668 cx,
669 |this, range, items, cx| {
670 for ix in range {
671 items.push(this.render_saved_conversation(ix, cx).into_any());
672 }
673 },
674 )
675 .flex(1., true)
676 .into_any()
677 })
678 .into_any()
679 }
680 }
681
682 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
683 self.has_focus = true;
684 self.toolbar
685 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
686 cx.notify();
687 if cx.is_self_focused() {
688 if let Some(editor) = self.active_editor() {
689 cx.focus(editor);
690 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
691 cx.focus(api_key_editor);
692 }
693 }
694 }
695
696 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
697 self.has_focus = false;
698 self.toolbar
699 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
700 cx.notify();
701 }
702}
703
704impl Panel for AssistantPanel {
705 fn position(&self, cx: &WindowContext) -> DockPosition {
706 match settings::get::<AssistantSettings>(cx).dock {
707 AssistantDockPosition::Left => DockPosition::Left,
708 AssistantDockPosition::Bottom => DockPosition::Bottom,
709 AssistantDockPosition::Right => DockPosition::Right,
710 }
711 }
712
713 fn position_is_valid(&self, _: DockPosition) -> bool {
714 true
715 }
716
717 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
718 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
719 let dock = match position {
720 DockPosition::Left => AssistantDockPosition::Left,
721 DockPosition::Bottom => AssistantDockPosition::Bottom,
722 DockPosition::Right => AssistantDockPosition::Right,
723 };
724 settings.dock = Some(dock);
725 });
726 }
727
728 fn size(&self, cx: &WindowContext) -> f32 {
729 let settings = settings::get::<AssistantSettings>(cx);
730 match self.position(cx) {
731 DockPosition::Left | DockPosition::Right => {
732 self.width.unwrap_or_else(|| settings.default_width)
733 }
734 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
735 }
736 }
737
738 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
739 match self.position(cx) {
740 DockPosition::Left | DockPosition::Right => self.width = Some(size),
741 DockPosition::Bottom => self.height = Some(size),
742 }
743 cx.notify();
744 }
745
746 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
747 matches!(event, AssistantPanelEvent::ZoomIn)
748 }
749
750 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
751 matches!(event, AssistantPanelEvent::ZoomOut)
752 }
753
754 fn is_zoomed(&self, _: &WindowContext) -> bool {
755 self.zoomed
756 }
757
758 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
759 self.zoomed = zoomed;
760 cx.notify();
761 }
762
763 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
764 if active {
765 if self.api_key.borrow().is_none() && !self.has_read_credentials {
766 self.has_read_credentials = true;
767 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
768 Some(api_key)
769 } else if let Some((_, api_key)) = cx
770 .platform()
771 .read_credentials(OPENAI_API_URL)
772 .log_err()
773 .flatten()
774 {
775 String::from_utf8(api_key).log_err()
776 } else {
777 None
778 };
779 if let Some(api_key) = api_key {
780 *self.api_key.borrow_mut() = Some(api_key);
781 } else if self.api_key_editor.is_none() {
782 self.api_key_editor = Some(build_api_key_editor(cx));
783 cx.notify();
784 }
785 }
786
787 if self.editors.is_empty() {
788 self.new_conversation(cx);
789 }
790 }
791 }
792
793 fn icon_path(&self) -> &'static str {
794 "icons/robot_14.svg"
795 }
796
797 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
798 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
799 }
800
801 fn should_change_position_on_event(event: &Self::Event) -> bool {
802 matches!(event, AssistantPanelEvent::DockPositionChanged)
803 }
804
805 fn should_activate_on_event(_: &Self::Event) -> bool {
806 false
807 }
808
809 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
810 matches!(event, AssistantPanelEvent::Close)
811 }
812
813 fn has_focus(&self, _: &WindowContext) -> bool {
814 self.has_focus
815 }
816
817 fn is_focus_event(event: &Self::Event) -> bool {
818 matches!(event, AssistantPanelEvent::Focus)
819 }
820}
821
/// Notifications emitted by a [`Conversation`].
enum ConversationEvent {
    /// Progress on a streamed completion. The emit site is not in this excerpt;
    /// TODO confirm exact semantics.
    StreamedCompletion,
    /// The conversation's buffer was edited (emitted from `handle_buffer_event`).
    MessagesEdited,
    /// The conversation's summary changed. The emit site is not in this excerpt.
    SummaryChanged,
}
827
/// A conversation's summary.
///
/// The default value is an empty, unfinished summary.
#[derive(Default)]
struct Summary {
    /// Whether the summary is final. Conversations restored from disk start with `true`.
    done: bool,
    /// The summary text.
    text: String,
}
833
/// A chat with the OpenAI API whose messages all live in one Markdown buffer.
struct Conversation {
    /// Text of every message, one after another.
    buffer: ModelHandle<Buffer>,
    /// Start anchor of each message in buffer order. An anchor can become invalid
    /// (see the `is_valid` check in `assist`).
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time and status of each message, keyed by id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id to hand out to the next message created.
    next_message_id: MessageId,
    /// Conversation summary, if one exists.
    summary: Option<Summary>,
    /// In-flight summary task. Not used in this excerpt; presumably replaced when a
    /// summary is requested.
    pending_summary: Task<Option<()>>,
    /// Not used in this excerpt. Presumably counts completions started; TODO confirm.
    completion_count: usize,
    /// Completions still streaming. Not used in this excerpt; TODO confirm lifecycle.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name; new conversations use "gpt-3.5-turbo-0613".
    model: String,
    /// Tokens used by the messages, or `None` until the first count completes.
    token_count: Option<usize>,
    /// Context window size of `model`, per `tiktoken_rs`.
    max_token_count: usize,
    /// Delayed token-count task started by `count_remaining_tokens`.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// In-flight save task. Not used in this excerpt.
    pending_save: Task<Result<()>>,
    /// File this conversation was loaded from; `None` for new conversations.
    path: Option<PathBuf>,
    /// Subscription to the buffer's events.
    _subscriptions: Vec<Subscription>,
}
852
/// Conversations notify their observers through [`ConversationEvent`]s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
856
857impl Conversation {
858 fn new(
859 api_key: Rc<RefCell<Option<String>>>,
860 language_registry: Arc<LanguageRegistry>,
861 cx: &mut ModelContext<Self>,
862 ) -> Self {
863 let model = "gpt-3.5-turbo-0613";
864 let markdown = language_registry.language_for_name("Markdown");
865 let buffer = cx.add_model(|cx| {
866 let mut buffer = Buffer::new(0, "", cx);
867 buffer.set_language_registry(language_registry);
868 cx.spawn_weak(|buffer, mut cx| async move {
869 let markdown = markdown.await?;
870 let buffer = buffer
871 .upgrade(&cx)
872 .ok_or_else(|| anyhow!("buffer was dropped"))?;
873 buffer.update(&mut cx, |buffer, cx| {
874 buffer.set_language(Some(markdown), cx)
875 });
876 anyhow::Ok(())
877 })
878 .detach_and_log_err(cx);
879 buffer
880 });
881
882 let mut this = Self {
883 message_anchors: Default::default(),
884 messages_metadata: Default::default(),
885 next_message_id: Default::default(),
886 summary: None,
887 pending_summary: Task::ready(None),
888 completion_count: Default::default(),
889 pending_completions: Default::default(),
890 token_count: None,
891 max_token_count: tiktoken_rs::model::get_context_size(model),
892 pending_token_count: Task::ready(None),
893 model: model.into(),
894 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
895 pending_save: Task::ready(Ok(())),
896 path: None,
897 api_key,
898 buffer,
899 };
900 let message = MessageAnchor {
901 id: MessageId(post_inc(&mut this.next_message_id.0)),
902 start: language::Anchor::MIN,
903 };
904 this.message_anchors.push(message.clone());
905 this.messages_metadata.insert(
906 message.id,
907 MessageMetadata {
908 role: Role::User,
909 sent_at: Local::now(),
910 status: MessageStatus::Done,
911 },
912 );
913
914 this.count_remaining_tokens(cx);
915 this
916 }
917
918 fn serialize(&self, cx: &AppContext) -> SavedConversation {
919 SavedConversation {
920 zed: "conversation".into(),
921 version: SavedConversation::VERSION.into(),
922 text: self.buffer.read(cx).text(),
923 message_metadata: self.messages_metadata.clone(),
924 messages: self
925 .messages(cx)
926 .map(|message| SavedMessage {
927 id: message.id,
928 start: message.offset_range.start,
929 })
930 .collect(),
931 summary: self
932 .summary
933 .as_ref()
934 .map(|summary| summary.text.clone())
935 .unwrap_or_default(),
936 model: self.model.clone(),
937 }
938 }
939
940 fn deserialize(
941 saved_conversation: SavedConversation,
942 path: PathBuf,
943 api_key: Rc<RefCell<Option<String>>>,
944 language_registry: Arc<LanguageRegistry>,
945 cx: &mut ModelContext<Self>,
946 ) -> Self {
947 let model = saved_conversation.model;
948 let markdown = language_registry.language_for_name("Markdown");
949 let mut message_anchors = Vec::new();
950 let mut next_message_id = MessageId(0);
951 let buffer = cx.add_model(|cx| {
952 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
953 for message in saved_conversation.messages {
954 message_anchors.push(MessageAnchor {
955 id: message.id,
956 start: buffer.anchor_before(message.start),
957 });
958 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
959 }
960 buffer.set_language_registry(language_registry);
961 cx.spawn_weak(|buffer, mut cx| async move {
962 let markdown = markdown.await?;
963 let buffer = buffer
964 .upgrade(&cx)
965 .ok_or_else(|| anyhow!("buffer was dropped"))?;
966 buffer.update(&mut cx, |buffer, cx| {
967 buffer.set_language(Some(markdown), cx)
968 });
969 anyhow::Ok(())
970 })
971 .detach_and_log_err(cx);
972 buffer
973 });
974
975 let mut this = Self {
976 message_anchors,
977 messages_metadata: saved_conversation.message_metadata,
978 next_message_id,
979 summary: Some(Summary {
980 text: saved_conversation.summary,
981 done: true,
982 }),
983 pending_summary: Task::ready(None),
984 completion_count: Default::default(),
985 pending_completions: Default::default(),
986 token_count: None,
987 max_token_count: tiktoken_rs::model::get_context_size(&model),
988 pending_token_count: Task::ready(None),
989 model,
990 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
991 pending_save: Task::ready(Ok(())),
992 path: Some(path),
993 api_key,
994 buffer,
995 };
996 this.count_remaining_tokens(cx);
997 this
998 }
999
1000 fn handle_buffer_event(
1001 &mut self,
1002 _: ModelHandle<Buffer>,
1003 event: &language::Event,
1004 cx: &mut ModelContext<Self>,
1005 ) {
1006 match event {
1007 language::Event::Edited => {
1008 self.count_remaining_tokens(cx);
1009 cx.emit(ConversationEvent::MessagesEdited);
1010 }
1011 _ => {}
1012 }
1013 }
1014
1015 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1016 let messages = self
1017 .messages(cx)
1018 .into_iter()
1019 .filter_map(|message| {
1020 Some(tiktoken_rs::ChatCompletionRequestMessage {
1021 role: match message.role {
1022 Role::User => "user".into(),
1023 Role::Assistant => "assistant".into(),
1024 Role::System => "system".into(),
1025 },
1026 content: self
1027 .buffer
1028 .read(cx)
1029 .text_for_range(message.offset_range)
1030 .collect(),
1031 name: None,
1032 })
1033 })
1034 .collect::<Vec<_>>();
1035 let model = self.model.clone();
1036 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1037 async move {
1038 cx.background().timer(Duration::from_millis(200)).await;
1039 let token_count = cx
1040 .background()
1041 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1042 .await?;
1043
1044 this.upgrade(&cx)
1045 .ok_or_else(|| anyhow!("conversation was dropped"))?
1046 .update(&mut cx, |this, cx| {
1047 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1048 this.token_count = Some(token_count);
1049 cx.notify()
1050 });
1051 anyhow::Ok(())
1052 }
1053 .log_err()
1054 });
1055 }
1056
1057 fn remaining_tokens(&self) -> Option<isize> {
1058 Some(self.max_token_count as isize - self.token_count? as isize)
1059 }
1060
    /// Switches the conversation to `model`, re-counts tokens and notifies observers.
    ///
    /// `max_token_count` is refreshed for the new model only once the scheduled count
    /// completes (see `count_remaining_tokens`).
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1066
1067 fn assist(
1068 &mut self,
1069 selected_messages: HashSet<MessageId>,
1070 cx: &mut ModelContext<Self>,
1071 ) -> Vec<MessageAnchor> {
1072 let mut user_messages = Vec::new();
1073 let mut tasks = Vec::new();
1074
1075 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1076 message
1077 .start
1078 .is_valid(self.buffer.read(cx))
1079 .then_some(message.id)
1080 });
1081
1082 for selected_message_id in selected_messages {
1083 let selected_message_role =
1084 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1085 metadata.role
1086 } else {
1087 continue;
1088 };
1089
1090 if selected_message_role == Role::Assistant {
1091 if let Some(user_message) = self.insert_message_after(
1092 selected_message_id,
1093 Role::User,
1094 MessageStatus::Done,
1095 cx,
1096 ) {
1097 user_messages.push(user_message);
1098 } else {
1099 continue;
1100 }
1101 } else {
1102 let request = OpenAIRequest {
1103 model: self.model.clone(),
1104 messages: self
1105 .messages(cx)
1106 .filter(|message| matches!(message.status, MessageStatus::Done))
1107 .flat_map(|message| {
1108 let mut system_message = None;
1109 if message.id == selected_message_id {
1110 system_message = Some(RequestMessage {
1111 role: Role::System,
1112 content: concat!(
1113 "Treat the following messages as additional knowledge you have learned about, ",
1114 "but act as if they were not part of this conversation. That is, treat them ",
1115 "as if the user didn't see them and couldn't possibly inquire about them."
1116 ).into()
1117 });
1118 }
1119
1120 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1121 })
1122 .chain(Some(RequestMessage {
1123 role: Role::System,
1124 content: format!(
1125 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1126 selected_message_id.0
1127 ),
1128 }))
1129 .collect(),
1130 stream: true,
1131 };
1132
1133 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1134 let stream = stream_completion(api_key, cx.background().clone(), request);
1135 let assistant_message = self
1136 .insert_message_after(
1137 selected_message_id,
1138 Role::Assistant,
1139 MessageStatus::Pending,
1140 cx,
1141 )
1142 .unwrap();
1143
1144 // Queue up the user's next reply
1145 if Some(selected_message_id) == last_message_id {
1146 let user_message = self
1147 .insert_message_after(
1148 assistant_message.id,
1149 Role::User,
1150 MessageStatus::Done,
1151 cx,
1152 )
1153 .unwrap();
1154 user_messages.push(user_message);
1155 }
1156
1157 tasks.push(cx.spawn_weak({
1158 |this, mut cx| async move {
1159 let assistant_message_id = assistant_message.id;
1160 let stream_completion = async {
1161 let mut messages = stream.await?;
1162
1163 while let Some(message) = messages.next().await {
1164 let mut message = message?;
1165 if let Some(choice) = message.choices.pop() {
1166 this.upgrade(&cx)
1167 .ok_or_else(|| anyhow!("conversation was dropped"))?
1168 .update(&mut cx, |this, cx| {
1169 let text: Arc<str> = choice.delta.content?.into();
1170 let message_ix = this.message_anchors.iter().position(
1171 |message| message.id == assistant_message_id,
1172 )?;
1173 this.buffer.update(cx, |buffer, cx| {
1174 let offset = this.message_anchors[message_ix + 1..]
1175 .iter()
1176 .find(|message| message.start.is_valid(buffer))
1177 .map_or(buffer.len(), |message| {
1178 message
1179 .start
1180 .to_offset(buffer)
1181 .saturating_sub(1)
1182 });
1183 buffer.edit([(offset..offset, text)], None, cx);
1184 });
1185 cx.emit(ConversationEvent::StreamedCompletion);
1186
1187 Some(())
1188 });
1189 }
1190 smol::future::yield_now().await;
1191 }
1192
1193 this.upgrade(&cx)
1194 .ok_or_else(|| anyhow!("conversation was dropped"))?
1195 .update(&mut cx, |this, cx| {
1196 this.pending_completions.retain(|completion| {
1197 completion.id != this.completion_count
1198 });
1199 this.summarize(cx);
1200 });
1201
1202 anyhow::Ok(())
1203 };
1204
1205 let result = stream_completion.await;
1206 if let Some(this) = this.upgrade(&cx) {
1207 this.update(&mut cx, |this, cx| {
1208 if let Some(metadata) =
1209 this.messages_metadata.get_mut(&assistant_message.id)
1210 {
1211 match result {
1212 Ok(_) => {
1213 metadata.status = MessageStatus::Done;
1214 }
1215 Err(error) => {
1216 metadata.status = MessageStatus::Error(
1217 error.to_string().trim().into(),
1218 );
1219 }
1220 }
1221 cx.notify();
1222 }
1223 });
1224 }
1225 }
1226 }));
1227 }
1228 }
1229
1230 if !tasks.is_empty() {
1231 self.pending_completions.push(PendingCompletion {
1232 id: post_inc(&mut self.completion_count),
1233 _tasks: tasks,
1234 });
1235 }
1236
1237 user_messages
1238 }
1239
1240 fn cancel_last_assist(&mut self) -> bool {
1241 self.pending_completions.pop().is_some()
1242 }
1243
1244 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1245 for id in ids {
1246 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1247 metadata.role.cycle();
1248 cx.emit(ConversationEvent::MessagesEdited);
1249 cx.notify();
1250 }
1251 }
1252 }
1253
1254 fn insert_message_after(
1255 &mut self,
1256 message_id: MessageId,
1257 role: Role,
1258 status: MessageStatus,
1259 cx: &mut ModelContext<Self>,
1260 ) -> Option<MessageAnchor> {
1261 if let Some(prev_message_ix) = self
1262 .message_anchors
1263 .iter()
1264 .position(|message| message.id == message_id)
1265 {
1266 // Find the next valid message after the one we were given.
1267 let mut next_message_ix = prev_message_ix + 1;
1268 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1269 if next_message.start.is_valid(self.buffer.read(cx)) {
1270 break;
1271 }
1272 next_message_ix += 1;
1273 }
1274
1275 let start = self.buffer.update(cx, |buffer, cx| {
1276 let offset = self
1277 .message_anchors
1278 .get(next_message_ix)
1279 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1280 buffer.edit([(offset..offset, "\n")], None, cx);
1281 buffer.anchor_before(offset + 1)
1282 });
1283 let message = MessageAnchor {
1284 id: MessageId(post_inc(&mut self.next_message_id.0)),
1285 start,
1286 };
1287 self.message_anchors
1288 .insert(next_message_ix, message.clone());
1289 self.messages_metadata.insert(
1290 message.id,
1291 MessageMetadata {
1292 role,
1293 sent_at: Local::now(),
1294 status,
1295 },
1296 );
1297 cx.emit(ConversationEvent::MessagesEdited);
1298 Some(message)
1299 } else {
1300 None
1301 }
1302 }
1303
1304 fn split_message(
1305 &mut self,
1306 range: Range<usize>,
1307 cx: &mut ModelContext<Self>,
1308 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1309 let start_message = self.message_for_offset(range.start, cx);
1310 let end_message = self.message_for_offset(range.end, cx);
1311 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1312 // Prevent splitting when range spans multiple messages.
1313 if start_message.id != end_message.id {
1314 return (None, None);
1315 }
1316
1317 let message = start_message;
1318 let role = message.role;
1319 let mut edited_buffer = false;
1320
1321 let mut suffix_start = None;
1322 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1323 {
1324 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1325 suffix_start = Some(range.end + 1);
1326 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1327 suffix_start = Some(range.end);
1328 }
1329 }
1330
1331 let suffix = if let Some(suffix_start) = suffix_start {
1332 MessageAnchor {
1333 id: MessageId(post_inc(&mut self.next_message_id.0)),
1334 start: self.buffer.read(cx).anchor_before(suffix_start),
1335 }
1336 } else {
1337 self.buffer.update(cx, |buffer, cx| {
1338 buffer.edit([(range.end..range.end, "\n")], None, cx);
1339 });
1340 edited_buffer = true;
1341 MessageAnchor {
1342 id: MessageId(post_inc(&mut self.next_message_id.0)),
1343 start: self.buffer.read(cx).anchor_before(range.end + 1),
1344 }
1345 };
1346
1347 self.message_anchors
1348 .insert(message.index_range.end + 1, suffix.clone());
1349 self.messages_metadata.insert(
1350 suffix.id,
1351 MessageMetadata {
1352 role,
1353 sent_at: Local::now(),
1354 status: MessageStatus::Done,
1355 },
1356 );
1357
1358 let new_messages =
1359 if range.start == range.end || range.start == message.offset_range.start {
1360 (None, Some(suffix))
1361 } else {
1362 let mut prefix_end = None;
1363 if range.start > message.offset_range.start
1364 && range.end < message.offset_range.end - 1
1365 {
1366 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1367 prefix_end = Some(range.start + 1);
1368 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1369 == Some('\n')
1370 {
1371 prefix_end = Some(range.start);
1372 }
1373 }
1374
1375 let selection = if let Some(prefix_end) = prefix_end {
1376 cx.emit(ConversationEvent::MessagesEdited);
1377 MessageAnchor {
1378 id: MessageId(post_inc(&mut self.next_message_id.0)),
1379 start: self.buffer.read(cx).anchor_before(prefix_end),
1380 }
1381 } else {
1382 self.buffer.update(cx, |buffer, cx| {
1383 buffer.edit([(range.start..range.start, "\n")], None, cx)
1384 });
1385 edited_buffer = true;
1386 MessageAnchor {
1387 id: MessageId(post_inc(&mut self.next_message_id.0)),
1388 start: self.buffer.read(cx).anchor_before(range.end + 1),
1389 }
1390 };
1391
1392 self.message_anchors
1393 .insert(message.index_range.end + 1, selection.clone());
1394 self.messages_metadata.insert(
1395 selection.id,
1396 MessageMetadata {
1397 role,
1398 sent_at: Local::now(),
1399 status: MessageStatus::Done,
1400 },
1401 );
1402 (Some(selection), Some(suffix))
1403 };
1404
1405 if !edited_buffer {
1406 cx.emit(ConversationEvent::MessagesEdited);
1407 }
1408 new_messages
1409 } else {
1410 (None, None)
1411 }
1412 }
1413
1414 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1415 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1416 let api_key = self.api_key.borrow().clone();
1417 if let Some(api_key) = api_key {
1418 let messages = self
1419 .messages(cx)
1420 .take(2)
1421 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1422 .chain(Some(RequestMessage {
1423 role: Role::User,
1424 content:
1425 "Summarize the conversation into a short title without punctuation"
1426 .into(),
1427 }));
1428 let request = OpenAIRequest {
1429 model: self.model.clone(),
1430 messages: messages.collect(),
1431 stream: true,
1432 };
1433
1434 let stream = stream_completion(api_key, cx.background().clone(), request);
1435 self.pending_summary = cx.spawn(|this, mut cx| {
1436 async move {
1437 let mut messages = stream.await?;
1438
1439 while let Some(message) = messages.next().await {
1440 let mut message = message?;
1441 if let Some(choice) = message.choices.pop() {
1442 let text = choice.delta.content.unwrap_or_default();
1443 this.update(&mut cx, |this, cx| {
1444 this.summary
1445 .get_or_insert(Default::default())
1446 .text
1447 .push_str(&text);
1448 cx.emit(ConversationEvent::SummaryChanged);
1449 });
1450 }
1451 }
1452
1453 this.update(&mut cx, |this, cx| {
1454 if let Some(summary) = this.summary.as_mut() {
1455 summary.done = true;
1456 cx.emit(ConversationEvent::SummaryChanged);
1457 }
1458 });
1459
1460 anyhow::Ok(())
1461 }
1462 .log_err()
1463 });
1464 }
1465 }
1466 }
1467
1468 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1469 self.messages_for_offsets([offset], cx).pop()
1470 }
1471
    /// Maps buffer offsets to the messages that contain them.
    ///
    /// Each message is reported at most once, in buffer order, for sorted input.
    /// The scan over messages only moves forward, so `offsets` are expected in
    /// ascending order. NOTE(review): callers pass selection heads; confirm those
    /// are always sorted.
    ///
    /// If no remaining message contains an offset, it resolves to the last
    /// message. Once the last message is reached, every remaining offset is
    /// attributed to it.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset, stopping at the last
            // message if none of the remaining ones does.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            // No messages at all: nothing can be matched.
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message. On the last message this
            // consumes all remaining offsets, so each message is pushed only once.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1502
1503 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1504 let buffer = self.buffer.read(cx);
1505 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1506 iter::from_fn(move || {
1507 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1508 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1509 let message_start = message_anchor.start.to_offset(buffer);
1510 let mut message_end = None;
1511 let mut end_ix = start_ix;
1512 while let Some((_, next_message)) = message_anchors.peek() {
1513 if next_message.start.is_valid(buffer) {
1514 message_end = Some(next_message.start);
1515 break;
1516 } else {
1517 end_ix += 1;
1518 message_anchors.next();
1519 }
1520 }
1521 let message_end = message_end
1522 .unwrap_or(language::Anchor::MAX)
1523 .to_offset(buffer);
1524 return Some(Message {
1525 index_range: start_ix..end_ix,
1526 offset_range: message_start..message_end,
1527 id: message_anchor.id,
1528 anchor: message_anchor.start,
1529 role: metadata.role,
1530 sent_at: metadata.sent_at,
1531 status: metadata.status.clone(),
1532 });
1533 }
1534 None
1535 })
1536 }
1537
1538 fn save(
1539 &mut self,
1540 debounce: Option<Duration>,
1541 fs: Arc<dyn Fs>,
1542 cx: &mut ModelContext<Conversation>,
1543 ) {
1544 self.pending_save = cx.spawn(|this, mut cx| async move {
1545 if let Some(debounce) = debounce {
1546 cx.background().timer(debounce).await;
1547 }
1548
1549 let (old_path, summary) = this.read_with(&cx, |this, _| {
1550 let path = this.path.clone();
1551 let summary = if let Some(summary) = this.summary.as_ref() {
1552 if summary.done {
1553 Some(summary.text.clone())
1554 } else {
1555 None
1556 }
1557 } else {
1558 None
1559 };
1560 (path, summary)
1561 });
1562
1563 if let Some(summary) = summary {
1564 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1565 let path = if let Some(old_path) = old_path {
1566 old_path
1567 } else {
1568 let mut discriminant = 1;
1569 let mut new_path;
1570 loop {
1571 new_path = CONVERSATIONS_DIR.join(&format!(
1572 "{} - {}.zed.json",
1573 summary.trim(),
1574 discriminant
1575 ));
1576 if fs.is_file(&new_path).await {
1577 discriminant += 1;
1578 } else {
1579 break;
1580 }
1581 }
1582 new_path
1583 };
1584
1585 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1586 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1587 .await?;
1588 this.update(&mut cx, |this, _| this.path = Some(path));
1589 }
1590
1591 Ok(())
1592 });
1593 }
1594}
1595
/// The completion tasks started by one `Conversation::assist` call.
struct PendingCompletion {
    /// Value taken from `Conversation::completion_count` when the group was pushed.
    id: usize,
    /// The streaming tasks. They are only held to keep them alive; presumably
    /// dropping them (e.g. via `cancel_last_assist`) cancels the streams.
    _tasks: Vec<Task<()>>,
}
1600
/// Events a `ConversationEditor` emits to whoever hosts it.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changes. The summary is what
    /// `ConversationEditor::title` returns.
    TabContentChanged,
}
1604
/// Where the newest cursor sat within the viewport.
///
/// Recorded so the view can keep the cursor's on-screen position stable while
/// a completion streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: horizontal scroll position. `y`: the cursor's display row minus the
    /// viewport's top row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1610
/// A view that edits a `Conversation` through an `Editor` on its buffer, with a
/// header block above each message.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to preserve while completions stream in. `None` when the
    /// view should not be scrolled to follow it.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1619
1620impl ConversationEditor {
1621 fn new(
1622 api_key: Rc<RefCell<Option<String>>>,
1623 language_registry: Arc<LanguageRegistry>,
1624 fs: Arc<dyn Fs>,
1625 cx: &mut ViewContext<Self>,
1626 ) -> Self {
1627 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1628 Self::for_conversation(conversation, fs, cx)
1629 }
1630
1631 fn for_conversation(
1632 conversation: ModelHandle<Conversation>,
1633 fs: Arc<dyn Fs>,
1634 cx: &mut ViewContext<Self>,
1635 ) -> Self {
1636 let editor = cx.add_view(|cx| {
1637 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1638 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1639 editor.set_show_gutter(false, cx);
1640 editor
1641 });
1642
1643 let _subscriptions = vec![
1644 cx.observe(&conversation, |_, _, cx| cx.notify()),
1645 cx.subscribe(&conversation, Self::handle_conversation_event),
1646 cx.subscribe(&editor, Self::handle_editor_event),
1647 ];
1648
1649 let mut this = Self {
1650 conversation,
1651 editor,
1652 blocks: Default::default(),
1653 scroll_position: None,
1654 fs,
1655 _subscriptions,
1656 };
1657 this.update_message_headers(cx);
1658 this
1659 }
1660
1661 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1662 let cursors = self.cursors(cx);
1663
1664 let user_messages = self.conversation.update(cx, |conversation, cx| {
1665 let selected_messages = conversation
1666 .messages_for_offsets(cursors, cx)
1667 .into_iter()
1668 .map(|message| message.id)
1669 .collect();
1670 conversation.assist(selected_messages, cx)
1671 });
1672 let new_selections = user_messages
1673 .iter()
1674 .map(|message| {
1675 let cursor = message
1676 .start
1677 .to_offset(self.conversation.read(cx).buffer.read(cx));
1678 cursor..cursor
1679 })
1680 .collect::<Vec<_>>();
1681 if !new_selections.is_empty() {
1682 self.editor.update(cx, |editor, cx| {
1683 editor.change_selections(
1684 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1685 cx,
1686 |selections| selections.select_ranges(new_selections),
1687 );
1688 });
1689 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1690 cx.defer(|this, _| this.scroll_position = None);
1691 }
1692 }
1693
1694 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1695 if !self
1696 .conversation
1697 .update(cx, |conversation, _| conversation.cancel_last_assist())
1698 {
1699 cx.propagate_action();
1700 }
1701 }
1702
1703 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1704 let cursors = self.cursors(cx);
1705 self.conversation.update(cx, |conversation, cx| {
1706 let messages = conversation
1707 .messages_for_offsets(cursors, cx)
1708 .into_iter()
1709 .map(|message| message.id)
1710 .collect();
1711 conversation.cycle_message_roles(messages, cx)
1712 });
1713 }
1714
1715 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1716 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1717 selections
1718 .into_iter()
1719 .map(|selection| selection.head())
1720 .collect()
1721 }
1722
1723 fn handle_conversation_event(
1724 &mut self,
1725 _: ModelHandle<Conversation>,
1726 event: &ConversationEvent,
1727 cx: &mut ViewContext<Self>,
1728 ) {
1729 match event {
1730 ConversationEvent::MessagesEdited => {
1731 self.update_message_headers(cx);
1732 self.conversation.update(cx, |conversation, cx| {
1733 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1734 });
1735 }
1736 ConversationEvent::SummaryChanged => {
1737 cx.emit(ConversationEditorEvent::TabContentChanged);
1738 self.conversation.update(cx, |conversation, cx| {
1739 conversation.save(None, self.fs.clone(), cx);
1740 });
1741 }
1742 ConversationEvent::StreamedCompletion => {
1743 self.editor.update(cx, |editor, cx| {
1744 if let Some(scroll_position) = self.scroll_position {
1745 let snapshot = editor.snapshot(cx);
1746 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1747 let scroll_top =
1748 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1749 editor.set_scroll_position(
1750 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1751 cx,
1752 );
1753 }
1754 });
1755 }
1756 }
1757 }
1758
1759 fn handle_editor_event(
1760 &mut self,
1761 _: ViewHandle<Editor>,
1762 event: &editor::Event,
1763 cx: &mut ViewContext<Self>,
1764 ) {
1765 match event {
1766 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1767 let cursor_scroll_position = self.cursor_scroll_position(cx);
1768 if *autoscroll {
1769 self.scroll_position = cursor_scroll_position;
1770 } else if self.scroll_position != cursor_scroll_position {
1771 self.scroll_position = None;
1772 }
1773 }
1774 editor::Event::SelectionsChanged { .. } => {
1775 self.scroll_position = self.cursor_scroll_position(cx);
1776 }
1777 _ => {}
1778 }
1779 }
1780
1781 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1782 self.editor.update(cx, |editor, cx| {
1783 let snapshot = editor.snapshot(cx);
1784 let cursor = editor.selections.newest_anchor().head();
1785 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1786 let scroll_position = editor
1787 .scroll_manager
1788 .anchor()
1789 .scroll_position(&snapshot.display_snapshot);
1790
1791 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1792 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1793 Some(ScrollPosition {
1794 cursor,
1795 offset_before_cursor: vec2f(
1796 scroll_position.x(),
1797 cursor_row - scroll_position.y(),
1798 ),
1799 })
1800 } else {
1801 None
1802 }
1803 })
1804 }
1805
    /// Rebuilds the header block shown above every message.
    ///
    /// Each header is two rows tall and contains:
    /// - the sender label ("You" / "Assistant" / "System"), which cycles the
    ///   message's role on left mouse-down;
    /// - the send time, formatted with `%I:%M%P`;
    /// - an error icon with a tooltip when the message's status is `Error`.
    ///
    /// All previously inserted header blocks are removed and replaced.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor was built with `Editor::for_buffer` in `for_conversation`,
            // so its multi-buffer holds exactly one excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        // The `move` closure below takes ownership of `message`, so the
                        // render callback works from this snapshot of the message.
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            // Clicking the sender label cycles this message's role.
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                // Failed messages get an error icon whose tooltip shows the error.
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1916
1917 fn quote_selection(
1918 workspace: &mut Workspace,
1919 _: &QuoteSelection,
1920 cx: &mut ViewContext<Workspace>,
1921 ) {
1922 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1923 return;
1924 };
1925 let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
1926 return;
1927 };
1928
1929 let text = editor.read_with(cx, |editor, cx| {
1930 let range = editor.selections.newest::<usize>(cx).range();
1931 let buffer = editor.buffer().read(cx).snapshot(cx);
1932 let start_language = buffer.language_at(range.start);
1933 let end_language = buffer.language_at(range.end);
1934 let language_name = if start_language == end_language {
1935 start_language.map(|language| language.name())
1936 } else {
1937 None
1938 };
1939 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1940
1941 let selected_text = buffer.text_for_range(range).collect::<String>();
1942 if selected_text.is_empty() {
1943 None
1944 } else {
1945 Some(if language_name == "markdown" {
1946 selected_text
1947 .lines()
1948 .map(|line| format!("> {}", line))
1949 .collect::<Vec<_>>()
1950 .join("\n")
1951 } else {
1952 format!("```{language_name}\n{selected_text}\n```")
1953 })
1954 }
1955 });
1956
1957 // Activate the panel
1958 if !panel.read(cx).has_focus(cx) {
1959 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1960 }
1961
1962 if let Some(text) = text {
1963 panel.update(cx, |panel, cx| {
1964 let conversation = panel
1965 .active_editor()
1966 .cloned()
1967 .unwrap_or_else(|| panel.new_conversation(cx));
1968 conversation.update(cx, |conversation, cx| {
1969 conversation
1970 .editor
1971 .update(cx, |editor, cx| editor.insert(&text, cx))
1972 });
1973 });
1974 }
1975 }
1976
1977 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1978 let editor = self.editor.read(cx);
1979 let conversation = self.conversation.read(cx);
1980 if editor.selections.count() == 1 {
1981 let selection = editor.selections.newest::<usize>(cx);
1982 let mut copied_text = String::new();
1983 let mut spanned_messages = 0;
1984 for message in conversation.messages(cx) {
1985 if message.offset_range.start >= selection.range().end {
1986 break;
1987 } else if message.offset_range.end >= selection.range().start {
1988 let range = cmp::max(message.offset_range.start, selection.range().start)
1989 ..cmp::min(message.offset_range.end, selection.range().end);
1990 if !range.is_empty() {
1991 spanned_messages += 1;
1992 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1993 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1994 copied_text.push_str(&chunk);
1995 }
1996 copied_text.push('\n');
1997 }
1998 }
1999 }
2000
2001 if spanned_messages > 1 {
2002 cx.platform()
2003 .write_to_clipboard(ClipboardItem::new(copied_text));
2004 return;
2005 }
2006 }
2007
2008 cx.propagate_action();
2009 }
2010
2011 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2012 self.conversation.update(cx, |conversation, cx| {
2013 let selections = self.editor.read(cx).selections.disjoint_anchors();
2014 for selection in selections.into_iter() {
2015 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2016 let range = selection
2017 .map(|endpoint| endpoint.to_offset(&buffer))
2018 .range();
2019 conversation.split_message(range, cx);
2020 }
2021 });
2022 }
2023
2024 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2025 self.conversation.update(cx, |conversation, cx| {
2026 conversation.save(None, self.fs.clone(), cx)
2027 });
2028 }
2029
2030 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2031 self.conversation.update(cx, |conversation, cx| {
2032 let new_model = match conversation.model.as_str() {
2033 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2034 _ => "gpt-4-0613",
2035 };
2036 conversation.set_model(new_model.into(), cx);
2037 });
2038 }
2039
2040 fn title(&self, cx: &AppContext) -> String {
2041 self.conversation
2042 .read(cx)
2043 .summary
2044 .as_ref()
2045 .map(|summary| summary.text.clone())
2046 .unwrap_or_else(|| "New Conversation".into())
2047 }
2048
2049 fn render_current_model(
2050 &self,
2051 style: &AssistantStyle,
2052 cx: &mut ViewContext<Self>,
2053 ) -> impl Element<Self> {
2054 enum Model {}
2055
2056 MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
2057 let style = style.model.style_for(state);
2058 Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
2059 .contained()
2060 .with_style(style.container)
2061 })
2062 .with_cursor_style(CursorStyle::PointingHand)
2063 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2064 }
2065
2066 fn render_remaining_tokens(
2067 &self,
2068 style: &AssistantStyle,
2069 cx: &mut ViewContext<Self>,
2070 ) -> Option<impl Element<Self>> {
2071 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2072 let remaining_tokens_style = if remaining_tokens <= 0 {
2073 &style.no_remaining_tokens
2074 } else if remaining_tokens <= 500 {
2075 &style.low_remaining_tokens
2076 } else {
2077 &style.remaining_tokens
2078 };
2079 Some(
2080 Label::new(
2081 remaining_tokens.to_string(),
2082 remaining_tokens_style.text.clone(),
2083 )
2084 .contained()
2085 .with_style(remaining_tokens_style.container),
2086 )
2087 }
2088}
2089
/// A `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2093
2094impl View for ConversationEditor {
2095 fn ui_name() -> &'static str {
2096 "ConversationEditor"
2097 }
2098
2099 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2100 let theme = &theme::current(cx).assistant;
2101 Stack::new()
2102 .with_child(
2103 ChildView::new(&self.editor, cx)
2104 .contained()
2105 .with_style(theme.container),
2106 )
2107 .with_child(
2108 Flex::row()
2109 .with_child(self.render_current_model(theme, cx))
2110 .with_children(self.render_remaining_tokens(theme, cx))
2111 .aligned()
2112 .top()
2113 .right(),
2114 )
2115 .into_any()
2116 }
2117
2118 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2119 if cx.is_self_focused() {
2120 cx.focus(&self.editor);
2121 }
2122 }
2123}
2124
/// A message's identity and its start position in the conversation buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    /// Start of the message. Code in this file checks `start.is_valid(buffer)`
    /// and treats messages whose anchor is no longer valid as merged into the
    /// preceding message.
    start: language::Anchor,
}
2130
/// A resolved view of one message, as produced by `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets from the message's anchor up to the start of the next
    /// valid message, or the end of the buffer.
    offset_range: Range<usize>,
    /// Indices into `message_anchors`. `start` is this message's anchor and
    /// `end` is the last invalid anchor folded into it; they are equal when
    /// none were folded.
    index_range: Range<usize>,
    id: MessageId,
    /// Anchor marking the start of the message in the buffer.
    anchor: language::Anchor,
    role: Role,
    /// Creation time taken from the message's metadata.
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2141
2142impl Message {
2143 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2144 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2145 content.extend(buffer.text_for_range(self.offset_range.clone()));
2146 RequestMessage {
2147 role: self.role,
2148 content: content.trim_end().into(),
2149 }
2150 }
2151}
2152
2153async fn stream_completion(
2154 api_key: String,
2155 executor: Arc<Background>,
2156 mut request: OpenAIRequest,
2157) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2158 request.stream = true;
2159
2160 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2161
2162 let json_data = serde_json::to_string(&request)?;
2163 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2164 .header("Content-Type", "application/json")
2165 .header("Authorization", format!("Bearer {}", api_key))
2166 .body(json_data)?
2167 .send_async()
2168 .await?;
2169
2170 let status = response.status();
2171 if status == StatusCode::OK {
2172 executor
2173 .spawn(async move {
2174 let mut lines = BufReader::new(response.body_mut()).lines();
2175
2176 fn parse_line(
2177 line: Result<String, io::Error>,
2178 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2179 if let Some(data) = line?.strip_prefix("data: ") {
2180 let event = serde_json::from_str(&data)?;
2181 Ok(Some(event))
2182 } else {
2183 Ok(None)
2184 }
2185 }
2186
2187 while let Some(line) = lines.next().await {
2188 if let Some(event) = parse_line(line).transpose() {
2189 let done = event.as_ref().map_or(false, |event| {
2190 event
2191 .choices
2192 .last()
2193 .map_or(false, |choice| choice.finish_reason.is_some())
2194 });
2195 if tx.unbounded_send(event).is_err() {
2196 break;
2197 }
2198
2199 if done {
2200 break;
2201 }
2202 }
2203 }
2204
2205 anyhow::Ok(())
2206 })
2207 .detach();
2208
2209 Ok(rx)
2210 } else {
2211 let mut body = String::new();
2212 response.body_mut().read_to_string(&mut body).await?;
2213
2214 #[derive(Deserialize)]
2215 struct OpenAIResponse {
2216 error: OpenAIError,
2217 }
2218
2219 #[derive(Deserialize)]
2220 struct OpenAIError {
2221 message: String,
2222 }
2223
2224 match serde_json::from_str::<OpenAIResponse>(&body) {
2225 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2226 "Failed to connect to OpenAI API: {}",
2227 response.error.message,
2228 )),
2229
2230 _ => Err(anyhow!(
2231 "Failed to connect to OpenAI API: {} {}",
2232 response.status(),
2233 body,
2234 )),
2235 }
2236 }
2237}
2238
#[cfg(test)]
mod tests {
    // Unit tests for `Conversation`'s message bookkeeping. They cover:
    // - how message ranges follow buffer edits;
    // - message splitting;
    // - offset-to-message lookups;
    // - serialization round-trips.
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Covers inserting messages, editing around them, and merging messages by deleting
    /// across boundaries. Undo/redo must restore or re-apply those merges.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting a message grows the buffer by one character, which ends up at the end of
        // the preceding message.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text inserted exactly at a message's start offset becomes part of that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after a message that already has a successor places the new message
        // directly after it, ahead of that successor.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Covers `split_message` with both empty and non-empty ranges. It checks the resulting
    /// buffer text (when newlines are reused vs. inserted) and the message ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range yields two new messages: one holding the selected text
        // and one holding the text that followed it.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// Covers `messages_for_offsets`, which maps buffer offsets to the messages containing
    /// them. This includes offsets at message boundaries and a trailing empty message.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        // Each inserted message boundary shows up as a newline in the buffer text.
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Several offsets inside the same message report that message only once.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        // An offset at the end of the buffer resolves to the trailing empty message.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Returns the ids of the messages that `messages_for_offsets` reports for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Serializing a conversation and deserializing it into a new model reproduces the
    /// buffer text and every message's id, role, and offset range.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // The undo reverts `_message_3`'s insertion: the asserted text and message list no
        // longer include it. The `finalize_last_transaction` call above presumably keeps the
        // earlier text edit out of this undo.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Snapshots the conversation's messages as `(id, role, offset_range)` tuples so tests
    /// can assert on them compactly.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}