1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API. Also used as the service key under which the API key is
/// stored in (and read from / deleted from) the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace. Most handlers are registered in `init`; several of
// these actions are also referenced by the header buttons' tooltips (`NewConversation`,
// `Assist`, `Split`, `QuoteSelection`) and by the panel icon tooltip (`ToggleFocus`).
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        // Forgets the stored OpenAI API key and shows the key prompt again.
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`]; the `Panel` impl maps them to dock behavior.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is zoomed.
    ZoomOut,
    /// Recognized by `Panel::is_focus_event`. NOTE(review): not emitted in the visible part
    /// of this file.
    Focus,
    /// Recognized by `Panel::should_close_on_event`. NOTE(review): not emitted in the visible
    /// part of this file.
    Close,
    /// Emitted when a settings change moves the panel's configured dock position.
    DockPositionChanged,
}
102
/// Dockable panel hosting OpenAI chat conversations and the list of saved conversations.
pub struct AssistantPanel {
    /// The workspace that owns this panel.
    workspace: WeakViewHandle<Workspace>,
    /// User-resized width when docked left/right; `None` falls back to the settings default.
    width: Option<f32>,
    /// User-resized height when docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    /// Index into `editors` of the displayed conversation; `None` shows the saved list.
    active_editor_index: Option<usize>,
    /// The previously active index, restored by the hamburger (history) button.
    prev_active_editor_index: Option<usize>,
    /// Every conversation editor opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations found on disk, refreshed when the directory changes.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// List state for the `UniformList` of saved conversations.
    saved_conversations_list_state: UniformListState,
    /// Whether the panel is currently zoomed.
    zoomed: bool,
    /// Whether the panel (or a descendant) currently has focus.
    has_focus: bool,
    /// Toolbar hosting the buffer search bar for the active conversation.
    toolbar: ViewHandle<Toolbar>,
    /// The OpenAI API key, shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// `Some` while the API key prompt is displayed instead of the conversations.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the environment/credential store has been consulted for a key.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Settings observer plus per-conversation subscriptions.
    subscriptions: Vec<Subscription>,
    /// Background task that re-lists saved conversations on directory changes.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 cx.notify();
196 })];
197
198 this
199 })
200 })
201 })
202 }
203
204 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
205 let editor = cx.add_view(|cx| {
206 ConversationEditor::new(
207 self.api_key.clone(),
208 self.languages.clone(),
209 self.fs.clone(),
210 cx,
211 )
212 });
213 self.add_conversation(editor.clone(), cx);
214 editor
215 }
216
217 fn add_conversation(
218 &mut self,
219 editor: ViewHandle<ConversationEditor>,
220 cx: &mut ViewContext<Self>,
221 ) {
222 self.subscriptions
223 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
224
225 let conversation = editor.read(cx).conversation.clone();
226 self.subscriptions
227 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
228
229 let index = self.editors.len();
230 self.editors.push(editor);
231 self.set_active_editor_index(Some(index), cx);
232 }
233
234 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
235 self.prev_active_editor_index = self.active_editor_index;
236 self.active_editor_index = index;
237 if let Some(editor) = self.active_editor() {
238 let editor = editor.read(cx).editor.clone();
239 self.toolbar.update(cx, |toolbar, cx| {
240 toolbar.set_active_item(Some(&editor), cx);
241 });
242 if self.has_focus(cx) {
243 cx.focus(&editor);
244 }
245 } else {
246 self.toolbar.update(cx, |toolbar, cx| {
247 toolbar.set_active_item(None, cx);
248 });
249 }
250
251 cx.notify();
252 }
253
254 fn handle_conversation_editor_event(
255 &mut self,
256 _: ViewHandle<ConversationEditor>,
257 event: &ConversationEditorEvent,
258 cx: &mut ViewContext<Self>,
259 ) {
260 match event {
261 ConversationEditorEvent::TabContentChanged => cx.notify(),
262 }
263 }
264
265 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
266 if let Some(api_key) = self
267 .api_key_editor
268 .as_ref()
269 .map(|editor| editor.read(cx).text(cx))
270 {
271 if !api_key.is_empty() {
272 cx.platform()
273 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
274 .log_err();
275 *self.api_key.borrow_mut() = Some(api_key);
276 self.api_key_editor.take();
277 cx.focus_self();
278 cx.notify();
279 }
280 } else {
281 cx.propagate_action();
282 }
283 }
284
285 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
286 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
287 self.api_key.take();
288 self.api_key_editor = Some(build_api_key_editor(cx));
289 cx.focus_self();
290 cx.notify();
291 }
292
293 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
294 if self.zoomed {
295 cx.emit(AssistantPanelEvent::ZoomOut)
296 } else {
297 cx.emit(AssistantPanelEvent::ZoomIn)
298 }
299 }
300
301 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
302 let mut propagate_action = true;
303 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
304 search_bar.update(cx, |search_bar, cx| {
305 if search_bar.show(cx) {
306 search_bar.search_suggested(cx);
307 if action.focus {
308 search_bar.select_query(cx);
309 cx.focus_self();
310 }
311 propagate_action = false
312 }
313 });
314 }
315 if propagate_action {
316 cx.propagate_action();
317 }
318 }
319
320 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
321 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
322 if !search_bar.read(cx).is_dismissed() {
323 search_bar.update(cx, |search_bar, cx| {
324 search_bar.dismiss(&Default::default(), cx)
325 });
326 return;
327 }
328 }
329 cx.propagate_action();
330 }
331
332 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
335 }
336 }
337
338 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
339 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
340 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
341 }
342 }
343
344 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
345 self.editors.get(self.active_editor_index?)
346 }
347
348 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
349 enum History {}
350 let theme = theme::current(cx);
351 let tooltip_style = theme::current(cx).tooltip.clone();
352 MouseEventHandler::new::<History, _>(0, cx, |state, _| {
353 let style = theme.assistant.hamburger_button.style_for(state);
354 Svg::for_style(style.icon.clone())
355 .contained()
356 .with_style(style.container)
357 })
358 .with_cursor_style(CursorStyle::PointingHand)
359 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
360 if this.active_editor().is_some() {
361 this.set_active_editor_index(None, cx);
362 } else {
363 this.set_active_editor_index(this.prev_active_editor_index, cx);
364 }
365 })
366 .with_tooltip::<History>(1, "History", None, tooltip_style, cx)
367 }
368
369 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
370 if self.active_editor().is_some() {
371 vec![
372 Self::render_split_button(cx).into_any(),
373 Self::render_quote_button(cx).into_any(),
374 Self::render_assist_button(cx).into_any(),
375 ]
376 } else {
377 Default::default()
378 }
379 }
380
381 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
382 let theme = theme::current(cx);
383 let tooltip_style = theme::current(cx).tooltip.clone();
384 MouseEventHandler::new::<Split, _>(0, cx, |state, _| {
385 let style = theme.assistant.split_button.style_for(state);
386 Svg::for_style(style.icon.clone())
387 .contained()
388 .with_style(style.container)
389 })
390 .with_cursor_style(CursorStyle::PointingHand)
391 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
392 if let Some(active_editor) = this.active_editor() {
393 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
394 }
395 })
396 .with_tooltip::<Split>(
397 1,
398 "Split Message",
399 Some(Box::new(Split)),
400 tooltip_style,
401 cx,
402 )
403 }
404
405 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
406 let theme = theme::current(cx);
407 let tooltip_style = theme::current(cx).tooltip.clone();
408 MouseEventHandler::new::<Assist, _>(0, cx, |state, _| {
409 let style = theme.assistant.assist_button.style_for(state);
410 Svg::for_style(style.icon.clone())
411 .contained()
412 .with_style(style.container)
413 })
414 .with_cursor_style(CursorStyle::PointingHand)
415 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
416 if let Some(active_editor) = this.active_editor() {
417 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
418 }
419 })
420 .with_tooltip::<Assist>(1, "Assist", Some(Box::new(Assist)), tooltip_style, cx)
421 }
422
423 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
424 let theme = theme::current(cx);
425 let tooltip_style = theme::current(cx).tooltip.clone();
426 MouseEventHandler::new::<QuoteSelection, _>(0, cx, |state, _| {
427 let style = theme.assistant.quote_button.style_for(state);
428 Svg::for_style(style.icon.clone())
429 .contained()
430 .with_style(style.container)
431 })
432 .with_cursor_style(CursorStyle::PointingHand)
433 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
434 if let Some(workspace) = this.workspace.upgrade(cx) {
435 cx.window_context().defer(move |cx| {
436 workspace.update(cx, |workspace, cx| {
437 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
438 });
439 });
440 }
441 })
442 .with_tooltip::<QuoteSelection>(
443 1,
444 "Quote Selection",
445 Some(Box::new(QuoteSelection)),
446 tooltip_style,
447 cx,
448 )
449 }
450
451 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
452 let theme = theme::current(cx);
453 let tooltip_style = theme::current(cx).tooltip.clone();
454 MouseEventHandler::new::<NewConversation, _>(0, cx, |state, _| {
455 let style = theme.assistant.plus_button.style_for(state);
456 Svg::for_style(style.icon.clone())
457 .contained()
458 .with_style(style.container)
459 })
460 .with_cursor_style(CursorStyle::PointingHand)
461 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
462 this.new_conversation(cx);
463 })
464 .with_tooltip::<NewConversation>(
465 1,
466 "New Conversation",
467 Some(Box::new(NewConversation)),
468 tooltip_style,
469 cx,
470 )
471 }
472
473 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
474 enum ToggleZoomButton {}
475
476 let theme = theme::current(cx);
477 let tooltip_style = theme::current(cx).tooltip.clone();
478 let style = if self.zoomed {
479 &theme.assistant.zoom_out_button
480 } else {
481 &theme.assistant.zoom_in_button
482 };
483
484 MouseEventHandler::new::<ToggleZoomButton, _>(0, cx, |state, _| {
485 let style = style.style_for(state);
486 Svg::for_style(style.icon.clone())
487 .contained()
488 .with_style(style.container)
489 })
490 .with_cursor_style(CursorStyle::PointingHand)
491 .on_click(MouseButton::Left, |_, this, cx| {
492 this.toggle_zoom(&ToggleZoom, cx);
493 })
494 .with_tooltip::<ToggleZoom>(
495 0,
496 if self.zoomed { "Zoom Out" } else { "Zoom In" },
497 Some(Box::new(ToggleZoom)),
498 tooltip_style,
499 cx,
500 )
501 }
502
503 fn render_saved_conversation(
504 &mut self,
505 index: usize,
506 cx: &mut ViewContext<Self>,
507 ) -> impl Element<Self> {
508 let conversation = &self.saved_conversations[index];
509 let path = conversation.path.clone();
510 MouseEventHandler::new::<SavedConversationMetadata, _>(index, cx, move |state, cx| {
511 let style = &theme::current(cx).assistant.saved_conversation;
512 Flex::row()
513 .with_child(
514 Label::new(
515 conversation.mtime.format("%F %I:%M%p").to_string(),
516 style.saved_at.text.clone(),
517 )
518 .aligned()
519 .contained()
520 .with_style(style.saved_at.container),
521 )
522 .with_child(
523 Label::new(conversation.title.clone(), style.title.text.clone())
524 .aligned()
525 .contained()
526 .with_style(style.title.container),
527 )
528 .contained()
529 .with_style(*style.container.style_for(state))
530 })
531 .with_cursor_style(CursorStyle::PointingHand)
532 .on_click(MouseButton::Left, move |_, this, cx| {
533 this.open_conversation(path.clone(), cx)
534 .detach_and_log_err(cx)
535 })
536 }
537
538 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
539 if let Some(ix) = self.editor_index_for_path(&path, cx) {
540 self.set_active_editor_index(Some(ix), cx);
541 return Task::ready(Ok(()));
542 }
543
544 let fs = self.fs.clone();
545 let api_key = self.api_key.clone();
546 let languages = self.languages.clone();
547 cx.spawn(|this, mut cx| async move {
548 let saved_conversation = fs.load(&path).await?;
549 let saved_conversation = serde_json::from_str(&saved_conversation)?;
550 let conversation = cx.add_model(|cx| {
551 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
552 });
553 this.update(&mut cx, |this, cx| {
554 // If, by the time we've loaded the conversation, the user has already opened
555 // the same conversation, we don't want to open it again.
556 if let Some(ix) = this.editor_index_for_path(&path, cx) {
557 this.set_active_editor_index(Some(ix), cx);
558 } else {
559 let editor = cx
560 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
561 this.add_conversation(editor, cx);
562 }
563 })?;
564 Ok(())
565 })
566 }
567
568 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
569 self.editors
570 .iter()
571 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
572 }
573}
574
575fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
576 cx.add_view(|cx| {
577 let mut editor = Editor::single_line(
578 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
579 cx,
580 );
581 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
582 editor
583 })
584}
585
// The panel communicates with its dock/workspace through `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
589
590impl View for AssistantPanel {
591 fn ui_name() -> &'static str {
592 "AssistantPanel"
593 }
594
595 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
596 let theme = &theme::current(cx);
597 let style = &theme.assistant;
598 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
599 Flex::column()
600 .with_child(
601 Text::new(
602 "Paste your OpenAI API key and press Enter to use the assistant",
603 style.api_key_prompt.text.clone(),
604 )
605 .aligned(),
606 )
607 .with_child(
608 ChildView::new(api_key_editor, cx)
609 .contained()
610 .with_style(style.api_key_editor.container)
611 .aligned(),
612 )
613 .contained()
614 .with_style(style.api_key_prompt.container)
615 .aligned()
616 .into_any()
617 } else {
618 let title = self.active_editor().map(|editor| {
619 Label::new(editor.read(cx).title(cx), style.title.text.clone())
620 .contained()
621 .with_style(style.title.container)
622 .aligned()
623 .left()
624 .flex(1., false)
625 });
626 let mut header = Flex::row()
627 .with_child(Self::render_hamburger_button(cx).aligned())
628 .with_children(title);
629 if self.has_focus {
630 header.add_children(
631 self.render_editor_tools(cx)
632 .into_iter()
633 .map(|tool| tool.aligned().flex_float()),
634 );
635 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
636 header.add_child(self.render_zoom_button(cx).aligned());
637 }
638
639 Flex::column()
640 .with_child(
641 header
642 .contained()
643 .with_style(theme.workspace.tab_bar.container)
644 .expanded()
645 .constrained()
646 .with_height(theme.workspace.tab_bar.height),
647 )
648 .with_children(if self.toolbar.read(cx).hidden() {
649 None
650 } else {
651 Some(ChildView::new(&self.toolbar, cx).expanded())
652 })
653 .with_child(if let Some(editor) = self.active_editor() {
654 ChildView::new(editor, cx).flex(1., true).into_any()
655 } else {
656 UniformList::new(
657 self.saved_conversations_list_state.clone(),
658 self.saved_conversations.len(),
659 cx,
660 |this, range, items, cx| {
661 for ix in range {
662 items.push(this.render_saved_conversation(ix, cx).into_any());
663 }
664 },
665 )
666 .flex(1., true)
667 .into_any()
668 })
669 .into_any()
670 }
671 }
672
673 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
674 self.has_focus = true;
675 self.toolbar
676 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
677 cx.notify();
678 if cx.is_self_focused() {
679 if let Some(editor) = self.active_editor() {
680 cx.focus(editor);
681 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
682 cx.focus(api_key_editor);
683 }
684 }
685 }
686
687 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
688 self.has_focus = false;
689 self.toolbar
690 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
691 cx.notify();
692 }
693}
694
695impl Panel for AssistantPanel {
696 fn position(&self, cx: &WindowContext) -> DockPosition {
697 match settings::get::<AssistantSettings>(cx).dock {
698 AssistantDockPosition::Left => DockPosition::Left,
699 AssistantDockPosition::Bottom => DockPosition::Bottom,
700 AssistantDockPosition::Right => DockPosition::Right,
701 }
702 }
703
704 fn position_is_valid(&self, _: DockPosition) -> bool {
705 true
706 }
707
708 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
709 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
710 let dock = match position {
711 DockPosition::Left => AssistantDockPosition::Left,
712 DockPosition::Bottom => AssistantDockPosition::Bottom,
713 DockPosition::Right => AssistantDockPosition::Right,
714 };
715 settings.dock = Some(dock);
716 });
717 }
718
719 fn size(&self, cx: &WindowContext) -> f32 {
720 let settings = settings::get::<AssistantSettings>(cx);
721 match self.position(cx) {
722 DockPosition::Left | DockPosition::Right => {
723 self.width.unwrap_or_else(|| settings.default_width)
724 }
725 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
726 }
727 }
728
729 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
730 match self.position(cx) {
731 DockPosition::Left | DockPosition::Right => self.width = Some(size),
732 DockPosition::Bottom => self.height = Some(size),
733 }
734 cx.notify();
735 }
736
737 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
738 matches!(event, AssistantPanelEvent::ZoomIn)
739 }
740
741 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
742 matches!(event, AssistantPanelEvent::ZoomOut)
743 }
744
745 fn is_zoomed(&self, _: &WindowContext) -> bool {
746 self.zoomed
747 }
748
749 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
750 self.zoomed = zoomed;
751 cx.notify();
752 }
753
754 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
755 if active {
756 if self.api_key.borrow().is_none() && !self.has_read_credentials {
757 self.has_read_credentials = true;
758 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
759 Some(api_key)
760 } else if let Some((_, api_key)) = cx
761 .platform()
762 .read_credentials(OPENAI_API_URL)
763 .log_err()
764 .flatten()
765 {
766 String::from_utf8(api_key).log_err()
767 } else {
768 None
769 };
770 if let Some(api_key) = api_key {
771 *self.api_key.borrow_mut() = Some(api_key);
772 } else if self.api_key_editor.is_none() {
773 self.api_key_editor = Some(build_api_key_editor(cx));
774 cx.notify();
775 }
776 }
777
778 if self.editors.is_empty() {
779 self.new_conversation(cx);
780 }
781 }
782 }
783
784 fn icon_path(&self, cx: &WindowContext) -> Option<&'static str> {
785 settings::get::<AssistantSettings>(cx)
786 .button
787 .then(|| "icons/ai.svg")
788 }
789
790 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
791 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
792 }
793
794 fn should_change_position_on_event(event: &Self::Event) -> bool {
795 matches!(event, AssistantPanelEvent::DockPositionChanged)
796 }
797
798 fn should_activate_on_event(_: &Self::Event) -> bool {
799 false
800 }
801
802 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
803 matches!(event, AssistantPanelEvent::Close)
804 }
805
806 fn has_focus(&self, _: &WindowContext) -> bool {
807 self.has_focus
808 }
809
810 fn is_focus_event(event: &Self::Event) -> bool {
811 matches!(event, AssistantPanelEvent::Focus)
812 }
813}
814
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// The conversation's buffer was edited.
    MessagesEdited,
    /// NOTE(review): presumably emitted when the generated summary changes; the emitting
    /// code is not visible here.
    SummaryChanged,
    /// NOTE(review): presumably emitted as streamed completion text arrives; the emitting
    /// code is not visible here.
    StreamedCompletion,
}
820
/// A conversation's summary text.
#[derive(Default)]
struct Summary {
    text: String,
    /// Set to `true` for summaries restored from disk. NOTE(review): presumably means the
    /// summary is complete rather than still streaming — confirm.
    done: bool,
}
826
/// A chat conversation: one Markdown buffer holding all messages' text, plus per-message
/// metadata and OpenAI request state.
struct Conversation {
    /// Buffer containing the text of every message.
    buffer: ModelHandle<Buffer>,
    /// Anchor at the start of each message within `buffer`.
    message_anchors: Vec<MessageAnchor>,
    /// Role, timestamp and status of each message, keyed by id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id to assign to the next message created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    // NOTE(review): the two fields below are used by code outside this view.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name used for requests and token counting.
    model: String,
    /// Token count from the most recent `count_remaining_tokens` run, if finished.
    token_count: Option<usize>,
    /// Context size of `model`, per `tiktoken_rs`.
    max_token_count: usize,
    /// In-flight token recount task.
    pending_token_count: Task<Option<()>>,
    /// The OpenAI API key, shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// NOTE(review): presumably the in-flight save task — confirm.
    pending_save: Task<Result<()>>,
    /// File this conversation was loaded from, if any; the panel matches open editors by it.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
}
845
// Conversation models notify observers through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
849
850impl Conversation {
/// Creates an empty conversation using `gpt-3.5-turbo-0613`, containing a single user
/// message that starts at the beginning of a fresh buffer.
fn new(
    api_key: Rc<RefCell<Option<String>>>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut ModelContext<Self>,
) -> Self {
    let model = "gpt-3.5-turbo-0613";
    // Markdown is resolved asynchronously and applied to the buffer once available;
    // failures (or the buffer being dropped first) are logged.
    let markdown = language_registry.language_for_name("Markdown");
    let buffer = cx.add_model(|cx| {
        let mut buffer = Buffer::new(0, "", cx);
        buffer.set_language_registry(language_registry);
        cx.spawn_weak(|buffer, mut cx| async move {
            let markdown = markdown.await?;
            let buffer = buffer
                .upgrade(&cx)
                .ok_or_else(|| anyhow!("buffer was dropped"))?;
            buffer.update(&mut cx, |buffer, cx| {
                buffer.set_language(Some(markdown), cx)
            });
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
        buffer
    });

    // `_subscriptions` is initialized before `buffer` is moved into the struct.
    let mut this = Self {
        message_anchors: Default::default(),
        messages_metadata: Default::default(),
        next_message_id: Default::default(),
        summary: None,
        pending_summary: Task::ready(None),
        completion_count: Default::default(),
        pending_completions: Default::default(),
        token_count: None,
        max_token_count: tiktoken_rs::model::get_context_size(model),
        pending_token_count: Task::ready(None),
        model: model.into(),
        _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
        pending_save: Task::ready(Ok(())),
        path: None,
        api_key,
        buffer,
    };
    // Seed the conversation with one empty user message at the buffer start.
    let message = MessageAnchor {
        id: MessageId(post_inc(&mut this.next_message_id.0)),
        start: language::Anchor::MIN,
    };
    this.message_anchors.push(message.clone());
    this.messages_metadata.insert(
        message.id,
        MessageMetadata {
            role: Role::User,
            sent_at: Local::now(),
            status: MessageStatus::Done,
        },
    );

    this.count_remaining_tokens(cx);
    this
}
910
911 fn serialize(&self, cx: &AppContext) -> SavedConversation {
912 SavedConversation {
913 zed: "conversation".into(),
914 version: SavedConversation::VERSION.into(),
915 text: self.buffer.read(cx).text(),
916 message_metadata: self.messages_metadata.clone(),
917 messages: self
918 .messages(cx)
919 .map(|message| SavedMessage {
920 id: message.id,
921 start: message.offset_range.start,
922 })
923 .collect(),
924 summary: self
925 .summary
926 .as_ref()
927 .map(|summary| summary.text.clone())
928 .unwrap_or_default(),
929 model: self.model.clone(),
930 }
931 }
932
/// Rebuilds a conversation from its saved form, loaded from `path`.
///
/// Message start anchors are recreated from the saved offsets. `next_message_id` is set to
/// one past the highest saved message id. The saved summary is restored and marked done.
fn deserialize(
    saved_conversation: SavedConversation,
    path: PathBuf,
    api_key: Rc<RefCell<Option<String>>>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut ModelContext<Self>,
) -> Self {
    let model = saved_conversation.model;
    let markdown = language_registry.language_for_name("Markdown");
    // Filled in by the buffer-construction closure below.
    let mut message_anchors = Vec::new();
    let mut next_message_id = MessageId(0);
    let buffer = cx.add_model(|cx| {
        let mut buffer = Buffer::new(0, saved_conversation.text, cx);
        for message in saved_conversation.messages {
            message_anchors.push(MessageAnchor {
                id: message.id,
                start: buffer.anchor_before(message.start),
            });
            next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
        }
        buffer.set_language_registry(language_registry);
        // Markdown is applied asynchronously once resolved; failures are logged.
        cx.spawn_weak(|buffer, mut cx| async move {
            let markdown = markdown.await?;
            let buffer = buffer
                .upgrade(&cx)
                .ok_or_else(|| anyhow!("buffer was dropped"))?;
            buffer.update(&mut cx, |buffer, cx| {
                buffer.set_language(Some(markdown), cx)
            });
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
        buffer
    });

    let mut this = Self {
        message_anchors,
        messages_metadata: saved_conversation.message_metadata,
        next_message_id,
        summary: Some(Summary {
            text: saved_conversation.summary,
            done: true,
        }),
        pending_summary: Task::ready(None),
        completion_count: Default::default(),
        pending_completions: Default::default(),
        token_count: None,
        max_token_count: tiktoken_rs::model::get_context_size(&model),
        pending_token_count: Task::ready(None),
        model,
        _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
        pending_save: Task::ready(Ok(())),
        path: Some(path),
        api_key,
        buffer,
    };
    this.count_remaining_tokens(cx);
    this
}
992
993 fn handle_buffer_event(
994 &mut self,
995 _: ModelHandle<Buffer>,
996 event: &language::Event,
997 cx: &mut ModelContext<Self>,
998 ) {
999 match event {
1000 language::Event::Edited => {
1001 self.count_remaining_tokens(cx);
1002 cx.emit(ConversationEvent::MessagesEdited);
1003 }
1004 _ => {}
1005 }
1006 }
1007
1008 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1009 let messages = self
1010 .messages(cx)
1011 .into_iter()
1012 .filter_map(|message| {
1013 Some(tiktoken_rs::ChatCompletionRequestMessage {
1014 role: match message.role {
1015 Role::User => "user".into(),
1016 Role::Assistant => "assistant".into(),
1017 Role::System => "system".into(),
1018 },
1019 content: self
1020 .buffer
1021 .read(cx)
1022 .text_for_range(message.offset_range)
1023 .collect(),
1024 name: None,
1025 })
1026 })
1027 .collect::<Vec<_>>();
1028 let model = self.model.clone();
1029 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1030 async move {
1031 cx.background().timer(Duration::from_millis(200)).await;
1032 let token_count = cx
1033 .background()
1034 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1035 .await?;
1036
1037 this.upgrade(&cx)
1038 .ok_or_else(|| anyhow!("conversation was dropped"))?
1039 .update(&mut cx, |this, cx| {
1040 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1041 this.token_count = Some(token_count);
1042 cx.notify()
1043 });
1044 anyhow::Ok(())
1045 }
1046 .log_err()
1047 });
1048 }
1049
1050 fn remaining_tokens(&self) -> Option<isize> {
1051 Some(self.max_token_count as isize - self.token_count? as isize)
1052 }
1053
/// Switches the OpenAI model and recounts tokens for it. The recount also refreshes
/// `max_token_count`.
fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
    self.model = model;
    self.count_remaining_tokens(cx);
    cx.notify();
}
1059
    /// Requests assistant replies for each message in `selected_messages`.
    ///
    /// Each selected message is handled according to its role:
    /// * Assistant message: nothing is sent. A new, empty `Done` user message is
    ///   inserted after it instead.
    /// * Any other message: the conversation's `Done` messages are sent to the
    ///   OpenAI API, with system instructions that steer the reply at the
    ///   selected message. An empty `Pending` assistant message is inserted after
    ///   the selected one and filled in as the completion streams. If the
    ///   selected message was the last message with a valid anchor, an empty user
    ///   message is also queued after the reply.
    ///
    /// Selected ids without metadata are skipped, as are non-assistant messages
    /// when no API key is configured. All streaming tasks started by this call
    /// are grouped into one `PendingCompletion`.
    ///
    /// Returns the user messages that were inserted, so the caller can move
    /// cursors into them.
    fn assist(
        &mut self,
        selected_messages: HashSet<MessageId>,
        cx: &mut ModelContext<Self>,
    ) -> Vec<MessageAnchor> {
        let mut user_messages = Vec::new();
        let mut tasks = Vec::new();

        // The last message whose start anchor is still valid. Replying to it also
        // queues a fresh user message below the reply.
        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
            message
                .start
                .is_valid(self.buffer.read(cx))
                .then_some(message.id)
        });

        for selected_message_id in selected_messages {
            let selected_message_role =
                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
                    metadata.role
                } else {
                    continue;
                };

            if selected_message_role == Role::Assistant {
                // "Assisting" on an assistant message just opens a new user
                // message after it; no request is sent.
                if let Some(user_message) = self.insert_message_after(
                    selected_message_id,
                    Role::User,
                    MessageStatus::Done,
                    cx,
                ) {
                    user_messages.push(user_message);
                } else {
                    continue;
                }
            } else {
                let request = OpenAIRequest {
                    model: self.model.clone(),
                    messages: self
                        .messages(cx)
                        // Only completed messages are sent; pending or errored
                        // ones are left out of the request.
                        .filter(|message| matches!(message.status, MessageStatus::Done))
                        .flat_map(|message| {
                            // Right after the selected message, tell the model to
                            // treat the remaining messages as background knowledge.
                            let mut system_message = None;
                            if message.id == selected_message_id {
                                system_message = Some(RequestMessage {
                                    role: Role::System,
                                    content: concat!(
                                        "Treat the following messages as additional knowledge you have learned about, ",
                                        "but act as if they were not part of this conversation. That is, treat them ",
                                        "as if the user didn't see them and couldn't possibly inquire about them."
                                    ).into()
                                });
                            }

                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
                        })
                        // Finally, point the reply at the selected message by id.
                        .chain(Some(RequestMessage {
                            role: Role::System,
                            content: format!(
                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
                                selected_message_id.0
                            ),
                        }))
                        .collect(),
                    stream: true,
                };

                // No API key: skip this message (nothing has been inserted yet).
                let Some(api_key) = self.api_key.borrow().clone() else { continue };
                let stream = stream_completion(api_key, cx.background().clone(), request);
                // NOTE(review): this unwrap assumes `selected_message_id`, found in
                // `messages_metadata` above, is also present in `message_anchors`.
                // Confirm the two collections always agree.
                let assistant_message = self
                    .insert_message_after(
                        selected_message_id,
                        Role::Assistant,
                        MessageStatus::Pending,
                        cx,
                    )
                    .unwrap();

                // Queue up the user's next reply
                if Some(selected_message_id) == last_message_id {
                    let user_message = self
                        .insert_message_after(
                            assistant_message.id,
                            Role::User,
                            MessageStatus::Done,
                            cx,
                        )
                        .unwrap();
                    user_messages.push(user_message);
                }

                tasks.push(cx.spawn_weak({
                    |this, mut cx| async move {
                        let assistant_message_id = assistant_message.id;
                        let stream_completion = async {
                            let mut messages = stream.await?;

                            while let Some(message) = messages.next().await {
                                let mut message = message?;
                                if let Some(choice) = message.choices.pop() {
                                    this.upgrade(&cx)
                                        .ok_or_else(|| anyhow!("conversation was dropped"))?
                                        .update(&mut cx, |this, cx| {
                                            // This closure returns `Option`, so a chunk
                                            // without content, or an assistant message
                                            // that no longer exists, is skipped silently.
                                            let text: Arc<str> = choice.delta.content?.into();
                                            let message_ix = this.message_anchors.iter().position(
                                                |message| message.id == assistant_message_id,
                                            )?;
                                            this.buffer.update(cx, |buffer, cx| {
                                                // Append at the end of the assistant message:
                                                // just before the newline that precedes the
                                                // next valid message, or at the buffer end.
                                                let offset = this.message_anchors[message_ix + 1..]
                                                    .iter()
                                                    .find(|message| message.start.is_valid(buffer))
                                                    .map_or(buffer.len(), |message| {
                                                        message
                                                            .start
                                                            .to_offset(buffer)
                                                            .saturating_sub(1)
                                                    });
                                                buffer.edit([(offset..offset, text)], None, cx);
                                            });
                                            cx.emit(ConversationEvent::StreamedCompletion);

                                            Some(())
                                        });
                                }
                                // Yield between chunks so other tasks on this executor can run.
                                smol::future::yield_now().await;
                            }

                            this.upgrade(&cx)
                                .ok_or_else(|| anyhow!("conversation was dropped"))?
                                .update(&mut cx, |this, cx| {
                                    // NOTE(review): `completion_count` is post-incremented
                                    // when this batch is pushed below, so here it holds the
                                    // id the *next* batch will get. This `retain` therefore
                                    // appears never to remove anything, and finished batches
                                    // linger in `pending_completions` (so
                                    // `cancel_last_assist` may "cancel" a finished batch).
                                    // A fix must capture this batch's id. It must also
                                    // avoid dropping sibling tasks that share the same
                                    // `PendingCompletion` while they are still streaming.
                                    this.pending_completions.retain(|completion| {
                                        completion.id != this.completion_count
                                    });
                                    this.summarize(cx);
                                });

                            anyhow::Ok(())
                        };

                        let result = stream_completion.await;
                        // Record the outcome on the assistant message: `Done` on
                        // success, otherwise `Error` with the trimmed error text.
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |this, cx| {
                                if let Some(metadata) =
                                    this.messages_metadata.get_mut(&assistant_message.id)
                                {
                                    match result {
                                        Ok(_) => {
                                            metadata.status = MessageStatus::Done;
                                        }
                                        Err(error) => {
                                            metadata.status = MessageStatus::Error(
                                                error.to_string().trim().into(),
                                            );
                                        }
                                    }
                                    cx.notify();
                                }
                            });
                        }
                    }
                }));
            }
        }

        if !tasks.is_empty() {
            self.pending_completions.push(PendingCompletion {
                id: post_inc(&mut self.completion_count),
                _tasks: tasks,
            });
        }

        user_messages
    }
1232
1233 fn cancel_last_assist(&mut self) -> bool {
1234 self.pending_completions.pop().is_some()
1235 }
1236
1237 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1238 for id in ids {
1239 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1240 metadata.role.cycle();
1241 cx.emit(ConversationEvent::MessagesEdited);
1242 cx.notify();
1243 }
1244 }
1245 }
1246
1247 fn insert_message_after(
1248 &mut self,
1249 message_id: MessageId,
1250 role: Role,
1251 status: MessageStatus,
1252 cx: &mut ModelContext<Self>,
1253 ) -> Option<MessageAnchor> {
1254 if let Some(prev_message_ix) = self
1255 .message_anchors
1256 .iter()
1257 .position(|message| message.id == message_id)
1258 {
1259 // Find the next valid message after the one we were given.
1260 let mut next_message_ix = prev_message_ix + 1;
1261 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1262 if next_message.start.is_valid(self.buffer.read(cx)) {
1263 break;
1264 }
1265 next_message_ix += 1;
1266 }
1267
1268 let start = self.buffer.update(cx, |buffer, cx| {
1269 let offset = self
1270 .message_anchors
1271 .get(next_message_ix)
1272 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1273 buffer.edit([(offset..offset, "\n")], None, cx);
1274 buffer.anchor_before(offset + 1)
1275 });
1276 let message = MessageAnchor {
1277 id: MessageId(post_inc(&mut self.next_message_id.0)),
1278 start,
1279 };
1280 self.message_anchors
1281 .insert(next_message_ix, message.clone());
1282 self.messages_metadata.insert(
1283 message.id,
1284 MessageMetadata {
1285 role,
1286 sent_at: Local::now(),
1287 status,
1288 },
1289 );
1290 cx.emit(ConversationEvent::MessagesEdited);
1291 Some(message)
1292 } else {
1293 None
1294 }
1295 }
1296
1297 fn split_message(
1298 &mut self,
1299 range: Range<usize>,
1300 cx: &mut ModelContext<Self>,
1301 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1302 let start_message = self.message_for_offset(range.start, cx);
1303 let end_message = self.message_for_offset(range.end, cx);
1304 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1305 // Prevent splitting when range spans multiple messages.
1306 if start_message.id != end_message.id {
1307 return (None, None);
1308 }
1309
1310 let message = start_message;
1311 let role = message.role;
1312 let mut edited_buffer = false;
1313
1314 let mut suffix_start = None;
1315 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1316 {
1317 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1318 suffix_start = Some(range.end + 1);
1319 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1320 suffix_start = Some(range.end);
1321 }
1322 }
1323
1324 let suffix = if let Some(suffix_start) = suffix_start {
1325 MessageAnchor {
1326 id: MessageId(post_inc(&mut self.next_message_id.0)),
1327 start: self.buffer.read(cx).anchor_before(suffix_start),
1328 }
1329 } else {
1330 self.buffer.update(cx, |buffer, cx| {
1331 buffer.edit([(range.end..range.end, "\n")], None, cx);
1332 });
1333 edited_buffer = true;
1334 MessageAnchor {
1335 id: MessageId(post_inc(&mut self.next_message_id.0)),
1336 start: self.buffer.read(cx).anchor_before(range.end + 1),
1337 }
1338 };
1339
1340 self.message_anchors
1341 .insert(message.index_range.end + 1, suffix.clone());
1342 self.messages_metadata.insert(
1343 suffix.id,
1344 MessageMetadata {
1345 role,
1346 sent_at: Local::now(),
1347 status: MessageStatus::Done,
1348 },
1349 );
1350
1351 let new_messages =
1352 if range.start == range.end || range.start == message.offset_range.start {
1353 (None, Some(suffix))
1354 } else {
1355 let mut prefix_end = None;
1356 if range.start > message.offset_range.start
1357 && range.end < message.offset_range.end - 1
1358 {
1359 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1360 prefix_end = Some(range.start + 1);
1361 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1362 == Some('\n')
1363 {
1364 prefix_end = Some(range.start);
1365 }
1366 }
1367
1368 let selection = if let Some(prefix_end) = prefix_end {
1369 cx.emit(ConversationEvent::MessagesEdited);
1370 MessageAnchor {
1371 id: MessageId(post_inc(&mut self.next_message_id.0)),
1372 start: self.buffer.read(cx).anchor_before(prefix_end),
1373 }
1374 } else {
1375 self.buffer.update(cx, |buffer, cx| {
1376 buffer.edit([(range.start..range.start, "\n")], None, cx)
1377 });
1378 edited_buffer = true;
1379 MessageAnchor {
1380 id: MessageId(post_inc(&mut self.next_message_id.0)),
1381 start: self.buffer.read(cx).anchor_before(range.end + 1),
1382 }
1383 };
1384
1385 self.message_anchors
1386 .insert(message.index_range.end + 1, selection.clone());
1387 self.messages_metadata.insert(
1388 selection.id,
1389 MessageMetadata {
1390 role,
1391 sent_at: Local::now(),
1392 status: MessageStatus::Done,
1393 },
1394 );
1395 (Some(selection), Some(suffix))
1396 };
1397
1398 if !edited_buffer {
1399 cx.emit(ConversationEvent::MessagesEdited);
1400 }
1401 new_messages
1402 } else {
1403 (None, None)
1404 }
1405 }
1406
1407 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1408 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1409 let api_key = self.api_key.borrow().clone();
1410 if let Some(api_key) = api_key {
1411 let messages = self
1412 .messages(cx)
1413 .take(2)
1414 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1415 .chain(Some(RequestMessage {
1416 role: Role::User,
1417 content:
1418 "Summarize the conversation into a short title without punctuation"
1419 .into(),
1420 }));
1421 let request = OpenAIRequest {
1422 model: self.model.clone(),
1423 messages: messages.collect(),
1424 stream: true,
1425 };
1426
1427 let stream = stream_completion(api_key, cx.background().clone(), request);
1428 self.pending_summary = cx.spawn(|this, mut cx| {
1429 async move {
1430 let mut messages = stream.await?;
1431
1432 while let Some(message) = messages.next().await {
1433 let mut message = message?;
1434 if let Some(choice) = message.choices.pop() {
1435 let text = choice.delta.content.unwrap_or_default();
1436 this.update(&mut cx, |this, cx| {
1437 this.summary
1438 .get_or_insert(Default::default())
1439 .text
1440 .push_str(&text);
1441 cx.emit(ConversationEvent::SummaryChanged);
1442 });
1443 }
1444 }
1445
1446 this.update(&mut cx, |this, cx| {
1447 if let Some(summary) = this.summary.as_mut() {
1448 summary.done = true;
1449 cx.emit(ConversationEvent::SummaryChanged);
1450 }
1451 });
1452
1453 anyhow::Ok(())
1454 }
1455 .log_err()
1456 });
1457 }
1458 }
1459 }
1460
1461 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1462 self.messages_for_offsets([offset], cx).pop()
1463 }
1464
    /// Maps buffer offsets to the messages that contain them.
    ///
    /// Messages and offsets are walked forward together and never rewound, so
    /// this presumably expects `offsets` in ascending order (callers pass
    /// selection heads). NOTE(review): confirm callers always sort.
    ///
    /// A run of consecutive offsets inside the same message produces that
    /// message once. An offset that no message's `offset_range` contains
    /// resolves to the last message. Once the last message is reached, all
    /// remaining offsets collapse into it. Returns an empty `Vec` when there are
    /// no messages.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Advance to the message that contains the offset, stopping at the
            // last message if none does.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            // Only reached when there are no messages at all.
            let Some(message) = current_message.as_ref() else { break };

            // Drop following offsets that fall in the same message, or all of them
            // once we are on the last message, so each message is reported once.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1495
    /// Iterates over the conversation's messages in anchor order, resolving each
    /// one's buffer offsets and metadata.
    ///
    /// Anchors whose start is no longer valid in the buffer are folded into the
    /// preceding message. A message therefore runs from its own start to the
    /// start of the next *valid* anchor, or to `Anchor::MAX` (presumably the end
    /// of the buffer). The first anchor is yielded without a validity check.
    ///
    /// `index_range` runs from the message's own anchor index to the index of the
    /// last anchor folded into it. Its `end` is that last index itself, not one
    /// past it, which is why new anchors are inserted at `index_range.end + 1`.
    ///
    /// NOTE(review): if an anchor has no entry in `messages_metadata`, the `?`
    /// below ends the whole iteration rather than skipping that anchor. Confirm
    /// this cannot happen or is intended.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume invalid (deleted) anchors until the next valid one,
                // whose start marks where this message ends.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1530
1531 fn save(
1532 &mut self,
1533 debounce: Option<Duration>,
1534 fs: Arc<dyn Fs>,
1535 cx: &mut ModelContext<Conversation>,
1536 ) {
1537 self.pending_save = cx.spawn(|this, mut cx| async move {
1538 if let Some(debounce) = debounce {
1539 cx.background().timer(debounce).await;
1540 }
1541
1542 let (old_path, summary) = this.read_with(&cx, |this, _| {
1543 let path = this.path.clone();
1544 let summary = if let Some(summary) = this.summary.as_ref() {
1545 if summary.done {
1546 Some(summary.text.clone())
1547 } else {
1548 None
1549 }
1550 } else {
1551 None
1552 };
1553 (path, summary)
1554 });
1555
1556 if let Some(summary) = summary {
1557 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1558 let path = if let Some(old_path) = old_path {
1559 old_path
1560 } else {
1561 let mut discriminant = 1;
1562 let mut new_path;
1563 loop {
1564 new_path = CONVERSATIONS_DIR.join(&format!(
1565 "{} - {}.zed.json",
1566 summary.trim(),
1567 discriminant
1568 ));
1569 if fs.is_file(&new_path).await {
1570 discriminant += 1;
1571 } else {
1572 break;
1573 }
1574 }
1575 new_path
1576 };
1577
1578 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1579 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1580 .await?;
1581 this.update(&mut cx, |this, _| this.path = Some(path));
1582 }
1583
1584 Ok(())
1585 });
1586 }
1587}
1588
/// The streaming tasks started by a single `Conversation::assist` call.
struct PendingCompletion {
    /// Sequence number taken from `Conversation::completion_count`.
    id: usize,
    /// Never read. Holding the tasks here ties their lifetime to this entry, so
    /// popping it (see `cancel_last_assist`) drops the task handles.
    _tasks: Vec<Task<()>>,
}
1593
/// Events emitted by a `ConversationEditor` to its subscribers.
enum ConversationEditorEvent {
    /// The conversation's summary (which is used as its title) changed,
    /// presumably so tab content can be refreshed.
    TabContentChanged,
}
1597
/// The newest cursor's position relative to the top of the viewport, used to
/// keep the view steady while an assistant reply streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x` is the horizontal scroll position. `y` is the number of rows from the
    /// top of the viewport down to the cursor.
    offset_before_cursor: Vector2F,
    /// The cursor (newest selection head) the offset is measured against.
    cursor: Anchor,
}
1603
/// A view that edits a single `Conversation`: an `Editor` over the
/// conversation's buffer, decorated with one header block per message.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    editor: ViewHandle<Editor>,
    /// Ids of the header blocks currently inserted into `editor`. They are
    /// replaced wholesale whenever the headers are rebuilt.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to restore while a completion streams,
    /// or `None` when there is nothing to restore.
    scroll_position: Option<ScrollPosition>,
    /// Keeps the conversation and editor subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
1612
impl ConversationEditor {
    /// Creates an editor for a brand-new `Conversation`.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Wraps an existing `conversation` in a new editor view.
    ///
    /// The inner `Editor` edits the conversation's buffer with soft wrap at the
    /// editor width and with no gutter or wrap guides. The view re-renders
    /// whenever the conversation notifies, and message header blocks are
    /// inserted immediately.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor.set_show_wrap_guides(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles `Assist`: requests a reply for every message that contains a
    /// cursor, then moves the cursors to any newly inserted user messages.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
            // Avoid scrolling to the new cursor position so the assistant's output is stable.
            cx.defer(|this, _| this.scroll_position = None);
        }
    }

    /// Handles `editor::Cancel`: cancels the most recent batch of completions,
    /// or propagates the action when there was nothing to cancel.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Handles `CycleMessageRole`: cycles the role of every message that
    /// contains a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Returns the buffer offset of every selection's head, in selection order.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to conversation events.
    ///
    /// * `MessagesEdited`: rebuilds the headers and schedules a save debounced
    ///   by 500ms.
    /// * `SummaryChanged`: asks for a tab refresh and saves immediately.
    /// * `StreamedCompletion`: restores the captured cursor-relative scroll
    ///   position.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    // Keep the cursor on the same viewport row as text streams in.
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks the cursor-relative scroll position.
    ///
    /// Autoscrolls and selection changes re-capture it. A non-autoscroll scroll
    /// clears it, unless the cursor's position relative to the viewport is
    /// unchanged.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Returns the newest cursor's position relative to the top of the viewport,
    /// or `None` when that cursor's row is outside the visible rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces every message header block with a freshly built one.
    ///
    /// Each header is a sticky, two-row block above the message. It shows the
    /// sender (clicking it cycles the message's role), the time the message was
    /// sent, and an error icon with a tooltip when the message's status is
    /// `Error`.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            // Empty enums used only as type tags for the mouse
                            // handler and tooltip (presumably to namespace their ids).
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::new::<Sender, _>(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Workspace action that quotes the active editor's newest selection into the
    /// assistant panel.
    ///
    /// The panel is focused first if it is not already focused. The text goes
    /// into the panel's active conversation, or a new one if there is none.
    /// Markdown selections are quoted line by line with `> `. Anything else is
    /// wrapped in a fenced code block tagged with the lowercased language name.
    /// The tag is empty when the selection's start and end have different
    /// languages. An empty selection inserts nothing but still focuses the panel.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy`.
    ///
    /// When there is exactly one selection and it spans more than one message,
    /// the clipboard receives the selected text with a `## <role>` heading per
    /// message. Otherwise the action propagates to the editor's regular copy.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    // Intersection of this message with the selection.
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Handles `Split`: splits the message under each selection.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                // Re-snapshot each time, since the previous split may have edited
                // the buffer.
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Handles `Save`: saves the conversation immediately, with no debounce.
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles the model from `gpt-4-0613` to `gpt-3.5-turbo-0613`. Any other
    /// current model switches to `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// The conversation's title: its summary text, which may still be streaming,
    /// or "New Conversation" when there is no summary.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Renders the current model name as a clickable label that cycles the model.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Renders the remaining-token count.
    ///
    /// The label is styled as exhausted (<= 0), low (<= 500) or normal. Returns
    /// `None` until a token count is available.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else if remaining_tokens <= 500 {
            &style.low_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2083
/// `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2087
2088impl View for ConversationEditor {
2089 fn ui_name() -> &'static str {
2090 "ConversationEditor"
2091 }
2092
2093 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2094 let theme = &theme::current(cx).assistant;
2095 Stack::new()
2096 .with_child(
2097 ChildView::new(&self.editor, cx)
2098 .contained()
2099 .with_style(theme.container),
2100 )
2101 .with_child(
2102 Flex::row()
2103 .with_child(self.render_current_model(theme, cx))
2104 .with_children(self.render_remaining_tokens(theme, cx))
2105 .aligned()
2106 .top()
2107 .right(),
2108 )
2109 .into_any()
2110 }
2111
2112 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2113 if cx.is_self_focused() {
2114 cx.focus(&self.editor);
2115 }
2116 }
2117}
2118
/// The persistent identity of a message: its id plus the buffer anchor where
/// its text starts. The anchor becomes invalid once that text is deleted.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
2124
/// A message resolved against the current buffer contents, as produced by
/// `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets covered by the message: from its start to the start of the
    /// next valid message, or to the end of the buffer.
    offset_range: Range<usize>,
    /// Indices into `message_anchors`: from this message's anchor to the last
    /// invalid anchor folded into it. `end` is that last index itself.
    index_range: Range<usize>,
    id: MessageId,
    /// The message's start anchor.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2135
2136impl Message {
2137 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2138 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2139 content.extend(buffer.text_for_range(self.offset_range.clone()));
2140 RequestMessage {
2141 role: self.role,
2142 content: content.trim_end().into(),
2143 }
2144 }
2145}
2146
2147async fn stream_completion(
2148 api_key: String,
2149 executor: Arc<Background>,
2150 mut request: OpenAIRequest,
2151) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2152 request.stream = true;
2153
2154 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2155
2156 let json_data = serde_json::to_string(&request)?;
2157 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2158 .header("Content-Type", "application/json")
2159 .header("Authorization", format!("Bearer {}", api_key))
2160 .body(json_data)?
2161 .send_async()
2162 .await?;
2163
2164 let status = response.status();
2165 if status == StatusCode::OK {
2166 executor
2167 .spawn(async move {
2168 let mut lines = BufReader::new(response.body_mut()).lines();
2169
2170 fn parse_line(
2171 line: Result<String, io::Error>,
2172 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2173 if let Some(data) = line?.strip_prefix("data: ") {
2174 let event = serde_json::from_str(&data)?;
2175 Ok(Some(event))
2176 } else {
2177 Ok(None)
2178 }
2179 }
2180
2181 while let Some(line) = lines.next().await {
2182 if let Some(event) = parse_line(line).transpose() {
2183 let done = event.as_ref().map_or(false, |event| {
2184 event
2185 .choices
2186 .last()
2187 .map_or(false, |choice| choice.finish_reason.is_some())
2188 });
2189 if tx.unbounded_send(event).is_err() {
2190 break;
2191 }
2192
2193 if done {
2194 break;
2195 }
2196 }
2197 }
2198
2199 anyhow::Ok(())
2200 })
2201 .detach();
2202
2203 Ok(rx)
2204 } else {
2205 let mut body = String::new();
2206 response.body_mut().read_to_string(&mut body).await?;
2207
2208 #[derive(Deserialize)]
2209 struct OpenAIResponse {
2210 error: OpenAIError,
2211 }
2212
2213 #[derive(Deserialize)]
2214 struct OpenAIError {
2215 message: String,
2216 }
2217
2218 match serde_json::from_str::<OpenAIResponse>(&body) {
2219 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2220 "Failed to connect to OpenAI API: {}",
2221 response.error.message,
2222 )),
2223
2224 _ => Err(anyhow!(
2225 "Failed to connect to OpenAI API: {} {}",
2226 response.status(),
2227 body,
2228 )),
2229 }
2230 }
2231}
2232
// Unit tests for `Conversation` message bookkeeping: inserting, merging, splitting,
// offset-to-message lookup, and serialization round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Inserts messages, edits the buffer around message boundaries, and checks that
    /// deleting across a boundary merges messages and that undo/redo restore and
    /// reapply that merge.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 adds one character to the buffer, so message_1 grows to 0..1.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between message_2 and message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Splits messages at several positions, including a non-empty range. Checks both
    /// the resulting buffer text (which newlines are reused and which are inserted) and
    /// the resulting message ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // `split_message` returns a pair of optional messages; only the second is used here.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range yields two new messages.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// Checks that `messages_for_offsets` maps buffer offsets to the messages that contain
    /// them, reporting each message once even when several offsets fall inside it.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which is reported once; 11 (end of buffer) maps to message_3.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // A trailing empty message is still reachable by its start offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Returns the ids of the messages `messages_for_offsets` yields for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Serializes a conversation and deserializes it into a new model, checking that
    /// the buffer text and the message ids, roles, and ranges are preserved.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Finalize so the next change lands in its own transaction and the undo below reverts only that change.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Collects `(id, role, offset_range)` for each message of `conversation`, in the
    /// order `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}