1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI API. It is also the key under which the API key is written to,
/// read from and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
51actions!(
52 assistant,
53 [
54 NewConversation,
55 Assist,
56 Split,
57 CycleMessageRole,
58 QuoteSelection,
59 ToggleFocus,
60 ResetKey,
61 ]
62);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by `AssistantPanel`.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not zoomed; treated as a zoom-in request.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is zoomed; treated as a zoom-out request.
    ZoomOut,
    /// Reported as a focus event by the `Panel` impl.
    Focus,
    /// Reported as a close request by the `Panel` impl.
    Close,
    /// Emitted when the dock position configured in settings changes.
    DockPositionChanged,
}
102
/// The assistant dock panel. It owns the open conversation editors, the list of saved
/// conversations shown as history, and the OpenAI API key state.
pub struct AssistantPanel {
    /// Weak handle to the workspace hosting this panel.
    workspace: WeakViewHandle<Workspace>,
    /// User-chosen width when docked left/right; `None` falls back to the settings default.
    width: Option<f32>,
    /// User-chosen height when docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    /// Index into `editors` of the visible conversation; `None` shows the saved-conversation list.
    active_editor_index: Option<usize>,
    /// Active index before the last change, restored when leaving the history list.
    prev_active_editor_index: Option<usize>,
    /// All conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations saved on disk, refreshed when the conversations dir changes.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// List state for the saved-conversation `UniformList`.
    saved_conversations_list_state: UniformListState,
    /// Whether the dock has zoomed this panel (set via `Panel::set_zoomed`).
    zoomed: bool,
    /// Whether the panel or one of its descendants currently has focus.
    has_focus: bool,
    /// Toolbar hosting the buffer search bar for the active conversation.
    toolbar: ViewHandle<Toolbar>,
    /// API key shared with every conversation; `None` until one is loaded or entered.
    api_key: Rc<RefCell<Option<String>>>,
    /// Single-line editor shown while prompting the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the environment/credential store has been consulted, so that happens only once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to settings, conversation editors and conversation models.
    subscriptions: Vec<Subscription>,
    /// Task watching the saved-conversations directory; held so it lives as long as the panel
    /// (presumably dropping it would cancel the watch).
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 })];
196
197 this
198 })
199 })
200 })
201 }
202
203 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
204 let editor = cx.add_view(|cx| {
205 ConversationEditor::new(
206 self.api_key.clone(),
207 self.languages.clone(),
208 self.fs.clone(),
209 cx,
210 )
211 });
212 self.add_conversation(editor.clone(), cx);
213 editor
214 }
215
216 fn add_conversation(
217 &mut self,
218 editor: ViewHandle<ConversationEditor>,
219 cx: &mut ViewContext<Self>,
220 ) {
221 self.subscriptions
222 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
223
224 let conversation = editor.read(cx).conversation.clone();
225 self.subscriptions
226 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
227
228 let index = self.editors.len();
229 self.editors.push(editor);
230 self.set_active_editor_index(Some(index), cx);
231 }
232
233 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
234 self.prev_active_editor_index = self.active_editor_index;
235 self.active_editor_index = index;
236 if let Some(editor) = self.active_editor() {
237 let editor = editor.read(cx).editor.clone();
238 self.toolbar.update(cx, |toolbar, cx| {
239 toolbar.set_active_item(Some(&editor), cx);
240 });
241 if self.has_focus(cx) {
242 cx.focus(&editor);
243 }
244 } else {
245 self.toolbar.update(cx, |toolbar, cx| {
246 toolbar.set_active_item(None, cx);
247 });
248 }
249
250 cx.notify();
251 }
252
253 fn handle_conversation_editor_event(
254 &mut self,
255 _: ViewHandle<ConversationEditor>,
256 event: &ConversationEditorEvent,
257 cx: &mut ViewContext<Self>,
258 ) {
259 match event {
260 ConversationEditorEvent::TabContentChanged => cx.notify(),
261 }
262 }
263
264 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
265 if let Some(api_key) = self
266 .api_key_editor
267 .as_ref()
268 .map(|editor| editor.read(cx).text(cx))
269 {
270 if !api_key.is_empty() {
271 cx.platform()
272 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
273 .log_err();
274 *self.api_key.borrow_mut() = Some(api_key);
275 self.api_key_editor.take();
276 cx.focus_self();
277 cx.notify();
278 }
279 } else {
280 cx.propagate_action();
281 }
282 }
283
284 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
285 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
286 self.api_key.take();
287 self.api_key_editor = Some(build_api_key_editor(cx));
288 cx.focus_self();
289 cx.notify();
290 }
291
292 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
293 if self.zoomed {
294 cx.emit(AssistantPanelEvent::ZoomOut)
295 } else {
296 cx.emit(AssistantPanelEvent::ZoomIn)
297 }
298 }
299
300 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
301 let mut propagate_action = true;
302 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
303 search_bar.update(cx, |search_bar, cx| {
304 if search_bar.show(cx) {
305 search_bar.search_suggested(cx);
306 if action.focus {
307 search_bar.select_query(cx);
308 cx.focus_self();
309 }
310 propagate_action = false
311 }
312 });
313 }
314 if propagate_action {
315 cx.propagate_action();
316 }
317 }
318
319 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
320 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
321 if !search_bar.read(cx).is_dismissed() {
322 search_bar.update(cx, |search_bar, cx| {
323 search_bar.dismiss(&Default::default(), cx)
324 });
325 return;
326 }
327 }
328 cx.propagate_action();
329 }
330
331 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
332 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
333 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
334 }
335 }
336
337 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
338 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
339 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
340 }
341 }
342
343 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
344 self.editors.get(self.active_editor_index?)
345 }
346
347 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
348 enum History {}
349 let theme = theme::current(cx);
350 let tooltip_style = theme::current(cx).tooltip.clone();
351 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
352 let style = theme.assistant.hamburger_button.style_for(state);
353 Svg::for_style(style.icon.clone())
354 .contained()
355 .with_style(style.container)
356 })
357 .with_cursor_style(CursorStyle::PointingHand)
358 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
359 if this.active_editor().is_some() {
360 this.set_active_editor_index(None, cx);
361 } else {
362 this.set_active_editor_index(this.prev_active_editor_index, cx);
363 }
364 })
365 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
366 }
367
368 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
369 if self.active_editor().is_some() {
370 vec![
371 Self::render_split_button(cx).into_any(),
372 Self::render_quote_button(cx).into_any(),
373 Self::render_assist_button(cx).into_any(),
374 ]
375 } else {
376 Default::default()
377 }
378 }
379
380 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
381 let theme = theme::current(cx);
382 let tooltip_style = theme::current(cx).tooltip.clone();
383 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
384 let style = theme.assistant.split_button.style_for(state);
385 Svg::for_style(style.icon.clone())
386 .contained()
387 .with_style(style.container)
388 })
389 .with_cursor_style(CursorStyle::PointingHand)
390 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
391 if let Some(active_editor) = this.active_editor() {
392 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
393 }
394 })
395 .with_tooltip::<Split>(
396 1,
397 "Split Message".into(),
398 Some(Box::new(Split)),
399 tooltip_style,
400 cx,
401 )
402 }
403
404 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
405 let theme = theme::current(cx);
406 let tooltip_style = theme::current(cx).tooltip.clone();
407 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
408 let style = theme.assistant.assist_button.style_for(state);
409 Svg::for_style(style.icon.clone())
410 .contained()
411 .with_style(style.container)
412 })
413 .with_cursor_style(CursorStyle::PointingHand)
414 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
415 if let Some(active_editor) = this.active_editor() {
416 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
417 }
418 })
419 .with_tooltip::<Assist>(
420 1,
421 "Assist".into(),
422 Some(Box::new(Assist)),
423 tooltip_style,
424 cx,
425 )
426 }
427
428 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
429 let theme = theme::current(cx);
430 let tooltip_style = theme::current(cx).tooltip.clone();
431 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
432 let style = theme.assistant.quote_button.style_for(state);
433 Svg::for_style(style.icon.clone())
434 .contained()
435 .with_style(style.container)
436 })
437 .with_cursor_style(CursorStyle::PointingHand)
438 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
439 if let Some(workspace) = this.workspace.upgrade(cx) {
440 cx.window_context().defer(move |cx| {
441 workspace.update(cx, |workspace, cx| {
442 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
443 });
444 });
445 }
446 })
447 .with_tooltip::<QuoteSelection>(
448 1,
449 "Quote Selection".into(),
450 Some(Box::new(QuoteSelection)),
451 tooltip_style,
452 cx,
453 )
454 }
455
456 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
457 let theme = theme::current(cx);
458 let tooltip_style = theme::current(cx).tooltip.clone();
459 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
460 let style = theme.assistant.plus_button.style_for(state);
461 Svg::for_style(style.icon.clone())
462 .contained()
463 .with_style(style.container)
464 })
465 .with_cursor_style(CursorStyle::PointingHand)
466 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
467 this.new_conversation(cx);
468 })
469 .with_tooltip::<NewConversation>(
470 1,
471 "New Conversation".into(),
472 Some(Box::new(NewConversation)),
473 tooltip_style,
474 cx,
475 )
476 }
477
478 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
479 enum ToggleZoomButton {}
480
481 let theme = theme::current(cx);
482 let tooltip_style = theme::current(cx).tooltip.clone();
483 let style = if self.zoomed {
484 &theme.assistant.zoom_out_button
485 } else {
486 &theme.assistant.zoom_in_button
487 };
488
489 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
490 let style = style.style_for(state);
491 Svg::for_style(style.icon.clone())
492 .contained()
493 .with_style(style.container)
494 })
495 .with_cursor_style(CursorStyle::PointingHand)
496 .on_click(MouseButton::Left, |_, this, cx| {
497 this.toggle_zoom(&ToggleZoom, cx);
498 })
499 .with_tooltip::<ToggleZoom>(
500 0,
501 if self.zoomed {
502 "Zoom Out".into()
503 } else {
504 "Zoom In".into()
505 },
506 Some(Box::new(ToggleZoom)),
507 tooltip_style,
508 cx,
509 )
510 }
511
512 fn render_saved_conversation(
513 &mut self,
514 index: usize,
515 cx: &mut ViewContext<Self>,
516 ) -> impl Element<Self> {
517 let conversation = &self.saved_conversations[index];
518 let path = conversation.path.clone();
519 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
520 let style = &theme::current(cx).assistant.saved_conversation;
521 Flex::row()
522 .with_child(
523 Label::new(
524 conversation.mtime.format("%F %I:%M%p").to_string(),
525 style.saved_at.text.clone(),
526 )
527 .aligned()
528 .contained()
529 .with_style(style.saved_at.container),
530 )
531 .with_child(
532 Label::new(conversation.title.clone(), style.title.text.clone())
533 .aligned()
534 .contained()
535 .with_style(style.title.container),
536 )
537 .contained()
538 .with_style(*style.container.style_for(state))
539 })
540 .with_cursor_style(CursorStyle::PointingHand)
541 .on_click(MouseButton::Left, move |_, this, cx| {
542 this.open_conversation(path.clone(), cx)
543 .detach_and_log_err(cx)
544 })
545 }
546
547 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
548 if let Some(ix) = self.editor_index_for_path(&path, cx) {
549 self.set_active_editor_index(Some(ix), cx);
550 return Task::ready(Ok(()));
551 }
552
553 let fs = self.fs.clone();
554 let api_key = self.api_key.clone();
555 let languages = self.languages.clone();
556 cx.spawn(|this, mut cx| async move {
557 let saved_conversation = fs.load(&path).await?;
558 let saved_conversation = serde_json::from_str(&saved_conversation)?;
559 let conversation = cx.add_model(|cx| {
560 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
561 });
562 this.update(&mut cx, |this, cx| {
563 // If, by the time we've loaded the conversation, the user has already opened
564 // the same conversation, we don't want to open it again.
565 if let Some(ix) = this.editor_index_for_path(&path, cx) {
566 this.set_active_editor_index(Some(ix), cx);
567 } else {
568 let editor = cx
569 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
570 this.add_conversation(editor, cx);
571 }
572 })?;
573 Ok(())
574 })
575 }
576
577 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
578 self.editors
579 .iter()
580 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
581 }
582}
583
584fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
585 cx.add_view(|cx| {
586 let mut editor = Editor::single_line(
587 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
588 cx,
589 );
590 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
591 editor
592 })
593}
594
/// The panel emits `AssistantPanelEvent`s; the `Panel` impl's `should_*_on_event` and
/// `is_focus_event` methods decide how each one is interpreted.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
598
impl View for AssistantPanel {
    fn ui_name() -> &'static str {
        "AssistantPanel"
    }

    /// Renders one of two states.
    ///
    /// While `api_key_editor` is set, this is the API-key prompt. Otherwise it is the normal
    /// panel: a tab-bar-styled header, then the toolbar when it isn't hidden, then the active
    /// conversation editor or, when none is active, the list of saved conversations.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        let theme = &theme::current(cx);
        let style = &theme.assistant;
        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
            Flex::column()
                .with_child(
                    Text::new(
                        "Paste your OpenAI API key and press Enter to use the assistant",
                        style.api_key_prompt.text.clone(),
                    )
                    .aligned(),
                )
                .with_child(
                    ChildView::new(api_key_editor, cx)
                        .contained()
                        .with_style(style.api_key_editor.container)
                        .aligned(),
                )
                .contained()
                .with_style(style.api_key_prompt.container)
                .aligned()
                .into_any()
        } else {
            // Title of the active conversation; absent while the history list is shown.
            let title = self.active_editor().map(|editor| {
                Label::new(editor.read(cx).title(cx), style.title.text.clone())
                    .contained()
                    .with_style(style.title.container)
                    .aligned()
                    .left()
                    .flex(1., false)
            });
            let mut header = Flex::row()
                .with_child(Self::render_hamburger_button(cx).aligned())
                .with_children(title);
            // Editor tools, the new-conversation button and the zoom button are only shown
            // while the panel has focus.
            if self.has_focus {
                header.add_children(
                    self.render_editor_tools(cx)
                        .into_iter()
                        .map(|tool| tool.aligned().flex_float()),
                );
                header.add_child(Self::render_plus_button(cx).aligned().flex_float());
                header.add_child(self.render_zoom_button(cx).aligned());
            }

            Flex::column()
                .with_child(
                    header
                        .contained()
                        .with_style(theme.workspace.tab_bar.container)
                        .expanded()
                        .constrained()
                        .with_height(theme.workspace.tab_bar.height),
                )
                .with_children(if self.toolbar.read(cx).hidden() {
                    None
                } else {
                    Some(ChildView::new(&self.toolbar, cx).expanded())
                })
                .with_child(if let Some(editor) = self.active_editor() {
                    ChildView::new(editor, cx).flex(1., true).into_any()
                } else {
                    // Saved-conversation history, rendered lazily row by row.
                    UniformList::new(
                        self.saved_conversations_list_state.clone(),
                        self.saved_conversations.len(),
                        cx,
                        |this, range, items, cx| {
                            for ix in range {
                                items.push(this.render_saved_conversation(ix, cx).into_any());
                            }
                        },
                    )
                    .flex(1., true)
                    .into_any()
                })
                .into_any()
        }
    }

    /// Records focus and informs the toolbar. When the panel view itself (not a descendant)
    /// gains focus, focus is forwarded to the active conversation editor, or else to the
    /// API-key editor if one is showing.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = true;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
        cx.notify();
        if cx.is_self_focused() {
            if let Some(editor) = self.active_editor() {
                cx.focus(editor);
            } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
                cx.focus(api_key_editor);
            }
        }
    }

    /// Records loss of focus and informs the toolbar.
    fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = false;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
        cx.notify();
    }
}
703
704impl Panel for AssistantPanel {
705 fn position(&self, cx: &WindowContext) -> DockPosition {
706 match settings::get::<AssistantSettings>(cx).dock {
707 AssistantDockPosition::Left => DockPosition::Left,
708 AssistantDockPosition::Bottom => DockPosition::Bottom,
709 AssistantDockPosition::Right => DockPosition::Right,
710 }
711 }
712
713 fn position_is_valid(&self, _: DockPosition) -> bool {
714 true
715 }
716
717 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
718 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
719 let dock = match position {
720 DockPosition::Left => AssistantDockPosition::Left,
721 DockPosition::Bottom => AssistantDockPosition::Bottom,
722 DockPosition::Right => AssistantDockPosition::Right,
723 };
724 settings.dock = Some(dock);
725 });
726 }
727
728 fn size(&self, cx: &WindowContext) -> f32 {
729 let settings = settings::get::<AssistantSettings>(cx);
730 match self.position(cx) {
731 DockPosition::Left | DockPosition::Right => {
732 self.width.unwrap_or_else(|| settings.default_width)
733 }
734 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
735 }
736 }
737
738 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
739 match self.position(cx) {
740 DockPosition::Left | DockPosition::Right => self.width = Some(size),
741 DockPosition::Bottom => self.height = Some(size),
742 }
743 cx.notify();
744 }
745
746 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
747 matches!(event, AssistantPanelEvent::ZoomIn)
748 }
749
750 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
751 matches!(event, AssistantPanelEvent::ZoomOut)
752 }
753
754 fn is_zoomed(&self, _: &WindowContext) -> bool {
755 self.zoomed
756 }
757
758 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
759 self.zoomed = zoomed;
760 cx.notify();
761 }
762
763 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
764 if active {
765 if self.api_key.borrow().is_none() && !self.has_read_credentials {
766 self.has_read_credentials = true;
767 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
768 Some(api_key)
769 } else if let Some((_, api_key)) = cx
770 .platform()
771 .read_credentials(OPENAI_API_URL)
772 .log_err()
773 .flatten()
774 {
775 String::from_utf8(api_key).log_err()
776 } else {
777 None
778 };
779 if let Some(api_key) = api_key {
780 *self.api_key.borrow_mut() = Some(api_key);
781 } else if self.api_key_editor.is_none() {
782 self.api_key_editor = Some(build_api_key_editor(cx));
783 cx.notify();
784 }
785 }
786
787 if self.editors.is_empty() {
788 self.new_conversation(cx);
789 }
790 }
791 }
792
793 fn icon_path(&self) -> &'static str {
794 "icons/robot_14.svg"
795 }
796
797 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
798 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
799 }
800
801 fn should_change_position_on_event(event: &Self::Event) -> bool {
802 matches!(event, AssistantPanelEvent::DockPositionChanged)
803 }
804
805 fn should_activate_on_event(_: &Self::Event) -> bool {
806 false
807 }
808
809 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
810 matches!(event, AssistantPanelEvent::Close)
811 }
812
813 fn has_focus(&self, _: &WindowContext) -> bool {
814 self.has_focus
815 }
816
817 fn is_focus_event(event: &Self::Event) -> bool {
818 matches!(event, AssistantPanelEvent::Focus)
819 }
820}
821
/// Events emitted by a `Conversation` model.
enum ConversationEvent {
    /// The conversation buffer was edited.
    MessagesEdited,
    /// Presumably emitted when `summary` changes; the emitting code is not visible here.
    SummaryChanged,
    /// Presumably emitted as completion text streams in; the emitting code is not visible here.
    StreamedCompletion,
}
827
/// Short summary of a conversation, as produced or loaded for it.
#[derive(Default)]
struct Summary {
    /// Summary text.
    text: String,
    /// Whether the summary is complete; summaries loaded from disk are marked done.
    done: bool,
}
833
/// A chat conversation. A single buffer holds the text of every message, alongside
/// per-message metadata, token accounting and the state of in-flight work.
struct Conversation {
    /// Text of all messages; message boundaries are marked by `message_anchors`.
    buffer: ModelHandle<Buffer>,
    /// Anchor at the start of each message (presumably in buffer order — confirm).
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id handed out to the next message that is created.
    next_message_id: MessageId,
    /// Conversation summary, if any.
    summary: Option<Summary>,
    /// Task producing `summary` (NOTE(review): inferred from the name; not visible here).
    pending_summary: Task<Option<()>>,
    /// NOTE(review): usage not visible in this section.
    completion_count: usize,
    /// Completion requests that are still streaming (inferred from the name).
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model used for requests and for token counting.
    model: String,
    /// Tokens used by the messages as of the last completed count; `None` before the first.
    token_count: Option<usize>,
    /// Context size of `model`.
    max_token_count: usize,
    /// Scheduled token-count task; replaced whenever a recount is requested.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel; completions are skipped while it is `None`.
    api_key: Rc<RefCell<Option<String>>>,
    /// Task saving the conversation to disk (inferred from the name).
    pending_save: Task<Result<()>>,
    /// File the conversation was loaded from; `None` for a new conversation.
    path: Option<PathBuf>,
    /// Subscription to the buffer's events.
    _subscriptions: Vec<Subscription>,
}
852
/// Conversation models emit `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
856
857impl Conversation {
858 fn new(
859 api_key: Rc<RefCell<Option<String>>>,
860 language_registry: Arc<LanguageRegistry>,
861 cx: &mut ModelContext<Self>,
862 ) -> Self {
863 let model = "gpt-3.5-turbo-0613";
864 let markdown = language_registry.language_for_name("Markdown");
865 let buffer = cx.add_model(|cx| {
866 let mut buffer = Buffer::new(0, "", cx);
867 buffer.set_language_registry(language_registry);
868 cx.spawn_weak(|buffer, mut cx| async move {
869 let markdown = markdown.await?;
870 let buffer = buffer
871 .upgrade(&cx)
872 .ok_or_else(|| anyhow!("buffer was dropped"))?;
873 buffer.update(&mut cx, |buffer, cx| {
874 buffer.set_language(Some(markdown), cx)
875 });
876 anyhow::Ok(())
877 })
878 .detach_and_log_err(cx);
879 buffer
880 });
881
882 let mut this = Self {
883 message_anchors: Default::default(),
884 messages_metadata: Default::default(),
885 next_message_id: Default::default(),
886 summary: None,
887 pending_summary: Task::ready(None),
888 completion_count: Default::default(),
889 pending_completions: Default::default(),
890 token_count: None,
891 max_token_count: tiktoken_rs::model::get_context_size(model),
892 pending_token_count: Task::ready(None),
893 model: model.into(),
894 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
895 pending_save: Task::ready(Ok(())),
896 path: None,
897 api_key,
898 buffer,
899 };
900 let message = MessageAnchor {
901 id: MessageId(post_inc(&mut this.next_message_id.0)),
902 start: language::Anchor::MIN,
903 };
904 this.message_anchors.push(message.clone());
905 this.messages_metadata.insert(
906 message.id,
907 MessageMetadata {
908 role: Role::User,
909 sent_at: Local::now(),
910 status: MessageStatus::Done,
911 },
912 );
913
914 this.count_remaining_tokens(cx);
915 this
916 }
917
918 fn serialize(&self, cx: &AppContext) -> SavedConversation {
919 SavedConversation {
920 zed: "conversation".into(),
921 version: SavedConversation::VERSION.into(),
922 text: self.buffer.read(cx).text(),
923 message_metadata: self.messages_metadata.clone(),
924 messages: self
925 .messages(cx)
926 .map(|message| SavedMessage {
927 id: message.id,
928 start: message.offset_range.start,
929 })
930 .collect(),
931 summary: self
932 .summary
933 .as_ref()
934 .map(|summary| summary.text.clone())
935 .unwrap_or_default(),
936 model: self.model.clone(),
937 }
938 }
939
940 fn deserialize(
941 saved_conversation: SavedConversation,
942 path: PathBuf,
943 api_key: Rc<RefCell<Option<String>>>,
944 language_registry: Arc<LanguageRegistry>,
945 cx: &mut ModelContext<Self>,
946 ) -> Self {
947 let model = saved_conversation.model;
948 let markdown = language_registry.language_for_name("Markdown");
949 let mut message_anchors = Vec::new();
950 let mut next_message_id = MessageId(0);
951 let buffer = cx.add_model(|cx| {
952 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
953 for message in saved_conversation.messages {
954 message_anchors.push(MessageAnchor {
955 id: message.id,
956 start: buffer.anchor_before(message.start),
957 });
958 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
959 }
960 buffer.set_language_registry(language_registry);
961 cx.spawn_weak(|buffer, mut cx| async move {
962 let markdown = markdown.await?;
963 let buffer = buffer
964 .upgrade(&cx)
965 .ok_or_else(|| anyhow!("buffer was dropped"))?;
966 buffer.update(&mut cx, |buffer, cx| {
967 buffer.set_language(Some(markdown), cx)
968 });
969 anyhow::Ok(())
970 })
971 .detach_and_log_err(cx);
972 buffer
973 });
974
975 let mut this = Self {
976 message_anchors,
977 messages_metadata: saved_conversation.message_metadata,
978 next_message_id,
979 summary: Some(Summary {
980 text: saved_conversation.summary,
981 done: true,
982 }),
983 pending_summary: Task::ready(None),
984 completion_count: Default::default(),
985 pending_completions: Default::default(),
986 token_count: None,
987 max_token_count: tiktoken_rs::model::get_context_size(&model),
988 pending_token_count: Task::ready(None),
989 model,
990 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
991 pending_save: Task::ready(Ok(())),
992 path: Some(path),
993 api_key,
994 buffer,
995 };
996 this.count_remaining_tokens(cx);
997 this
998 }
999
1000 fn handle_buffer_event(
1001 &mut self,
1002 _: ModelHandle<Buffer>,
1003 event: &language::Event,
1004 cx: &mut ModelContext<Self>,
1005 ) {
1006 match event {
1007 language::Event::Edited => {
1008 self.count_remaining_tokens(cx);
1009 cx.emit(ConversationEvent::MessagesEdited);
1010 }
1011 _ => {}
1012 }
1013 }
1014
1015 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1016 let messages = self
1017 .messages(cx)
1018 .into_iter()
1019 .filter_map(|message| {
1020 Some(tiktoken_rs::ChatCompletionRequestMessage {
1021 role: match message.role {
1022 Role::User => "user".into(),
1023 Role::Assistant => "assistant".into(),
1024 Role::System => "system".into(),
1025 },
1026 content: self
1027 .buffer
1028 .read(cx)
1029 .text_for_range(message.offset_range)
1030 .collect(),
1031 name: None,
1032 })
1033 })
1034 .collect::<Vec<_>>();
1035 let model = self.model.clone();
1036 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1037 async move {
1038 cx.background().timer(Duration::from_millis(200)).await;
1039 let token_count = cx
1040 .background()
1041 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1042 .await?;
1043
1044 this.upgrade(&cx)
1045 .ok_or_else(|| anyhow!("conversation was dropped"))?
1046 .update(&mut cx, |this, cx| {
1047 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1048 this.token_count = Some(token_count);
1049 cx.notify()
1050 });
1051 anyhow::Ok(())
1052 }
1053 .log_err()
1054 });
1055 }
1056
1057 fn remaining_tokens(&self) -> Option<isize> {
1058 Some(self.max_token_count as isize - self.token_count? as isize)
1059 }
1060
    /// Switches the conversation to `model` and notifies observers.
    ///
    /// It also schedules a token recount; `max_token_count` is refreshed for the new model
    /// once that recount completes.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1066
1067 fn assist(
1068 &mut self,
1069 selected_messages: HashSet<MessageId>,
1070 cx: &mut ModelContext<Self>,
1071 ) -> Vec<MessageAnchor> {
1072 let mut user_messages = Vec::new();
1073 let mut tasks = Vec::new();
1074
1075 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1076 message
1077 .start
1078 .is_valid(self.buffer.read(cx))
1079 .then_some(message.id)
1080 });
1081
1082 for selected_message_id in selected_messages {
1083 let selected_message_role =
1084 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1085 metadata.role
1086 } else {
1087 continue;
1088 };
1089
1090 if selected_message_role == Role::Assistant {
1091 if let Some(user_message) = self.insert_message_after(
1092 selected_message_id,
1093 Role::User,
1094 MessageStatus::Done,
1095 cx,
1096 ) {
1097 user_messages.push(user_message);
1098 } else {
1099 continue;
1100 }
1101 } else {
1102 let request = OpenAIRequest {
1103 model: self.model.clone(),
1104 messages: self
1105 .messages(cx)
1106 .filter(|message| matches!(message.status, MessageStatus::Done))
1107 .flat_map(|message| {
1108 let mut system_message = None;
1109 if message.id == selected_message_id {
1110 system_message = Some(RequestMessage {
1111 role: Role::System,
1112 content: concat!(
1113 "Treat the following messages as additional knowledge you have learned about, ",
1114 "but act as if they were not part of this conversation. That is, treat them ",
1115 "as if the user didn't see them and couldn't possibly inquire about them."
1116 ).into()
1117 });
1118 }
1119
1120 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1121 })
1122 .chain(Some(RequestMessage {
1123 role: Role::System,
1124 content: format!(
1125 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1126 selected_message_id.0
1127 ),
1128 }))
1129 .collect(),
1130 stream: true,
1131 };
1132
1133 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1134 let stream = stream_completion(api_key, cx.background().clone(), request);
1135 let assistant_message = self
1136 .insert_message_after(
1137 selected_message_id,
1138 Role::Assistant,
1139 MessageStatus::Pending,
1140 cx,
1141 )
1142 .unwrap();
1143
1144 // Queue up the user's next reply
1145 if Some(selected_message_id) == last_message_id {
1146 let user_message = self
1147 .insert_message_after(
1148 assistant_message.id,
1149 Role::User,
1150 MessageStatus::Done,
1151 cx,
1152 )
1153 .unwrap();
1154 user_messages.push(user_message);
1155 }
1156
1157 tasks.push(cx.spawn_weak({
1158 |this, mut cx| async move {
1159 let assistant_message_id = assistant_message.id;
1160 let stream_completion = async {
1161 let mut messages = stream.await?;
1162
1163 while let Some(message) = messages.next().await {
1164 let mut message = message?;
1165 if let Some(choice) = message.choices.pop() {
1166 this.upgrade(&cx)
1167 .ok_or_else(|| anyhow!("conversation was dropped"))?
1168 .update(&mut cx, |this, cx| {
1169 let text: Arc<str> = choice.delta.content?.into();
1170 let message_ix = this.message_anchors.iter().position(
1171 |message| message.id == assistant_message_id,
1172 )?;
1173 this.buffer.update(cx, |buffer, cx| {
1174 let offset = this.message_anchors[message_ix + 1..]
1175 .iter()
1176 .find(|message| message.start.is_valid(buffer))
1177 .map_or(buffer.len(), |message| {
1178 message
1179 .start
1180 .to_offset(buffer)
1181 .saturating_sub(1)
1182 });
1183 buffer.edit([(offset..offset, text)], None, cx);
1184 });
1185 cx.emit(ConversationEvent::StreamedCompletion);
1186
1187 Some(())
1188 });
1189 }
1190 smol::future::yield_now().await;
1191 }
1192
1193 this.upgrade(&cx)
1194 .ok_or_else(|| anyhow!("conversation was dropped"))?
1195 .update(&mut cx, |this, cx| {
1196 this.pending_completions.retain(|completion| {
1197 completion.id != this.completion_count
1198 });
1199 this.summarize(cx);
1200 });
1201
1202 anyhow::Ok(())
1203 };
1204
1205 let result = stream_completion.await;
1206 if let Some(this) = this.upgrade(&cx) {
1207 this.update(&mut cx, |this, cx| {
1208 if let Some(metadata) =
1209 this.messages_metadata.get_mut(&assistant_message.id)
1210 {
1211 match result {
1212 Ok(_) => {
1213 metadata.status = MessageStatus::Done;
1214 }
1215 Err(error) => {
1216 metadata.status = MessageStatus::Error(
1217 error.to_string().trim().into(),
1218 );
1219 }
1220 }
1221 cx.notify();
1222 }
1223 });
1224 }
1225 }
1226 }));
1227 }
1228 }
1229
1230 if !tasks.is_empty() {
1231 self.pending_completions.push(PendingCompletion {
1232 id: post_inc(&mut self.completion_count),
1233 _tasks: tasks,
1234 });
1235 }
1236
1237 user_messages
1238 }
1239
1240 fn cancel_last_assist(&mut self) -> bool {
1241 self.pending_completions.pop().is_some()
1242 }
1243
1244 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1245 for id in ids {
1246 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1247 metadata.role.cycle();
1248 cx.emit(ConversationEvent::MessagesEdited);
1249 cx.notify();
1250 }
1251 }
1252 }
1253
1254 fn insert_message_after(
1255 &mut self,
1256 message_id: MessageId,
1257 role: Role,
1258 status: MessageStatus,
1259 cx: &mut ModelContext<Self>,
1260 ) -> Option<MessageAnchor> {
1261 if let Some(prev_message_ix) = self
1262 .message_anchors
1263 .iter()
1264 .position(|message| message.id == message_id)
1265 {
1266 // Find the next valid message after the one we were given.
1267 let mut next_message_ix = prev_message_ix + 1;
1268 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1269 if next_message.start.is_valid(self.buffer.read(cx)) {
1270 break;
1271 }
1272 next_message_ix += 1;
1273 }
1274
1275 let start = self.buffer.update(cx, |buffer, cx| {
1276 let offset = self
1277 .message_anchors
1278 .get(next_message_ix)
1279 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1280 buffer.edit([(offset..offset, "\n")], None, cx);
1281 buffer.anchor_before(offset + 1)
1282 });
1283 let message = MessageAnchor {
1284 id: MessageId(post_inc(&mut self.next_message_id.0)),
1285 start,
1286 };
1287 self.message_anchors
1288 .insert(next_message_ix, message.clone());
1289 self.messages_metadata.insert(
1290 message.id,
1291 MessageMetadata {
1292 role,
1293 sent_at: Local::now(),
1294 status,
1295 },
1296 );
1297 cx.emit(ConversationEvent::MessagesEdited);
1298 Some(message)
1299 } else {
1300 None
1301 }
1302 }
1303
1304 fn split_message(
1305 &mut self,
1306 range: Range<usize>,
1307 cx: &mut ModelContext<Self>,
1308 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1309 let start_message = self.message_for_offset(range.start, cx);
1310 let end_message = self.message_for_offset(range.end, cx);
1311 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1312 // Prevent splitting when range spans multiple messages.
1313 if start_message.id != end_message.id {
1314 return (None, None);
1315 }
1316
1317 let message = start_message;
1318 let role = message.role;
1319 let mut edited_buffer = false;
1320
1321 let mut suffix_start = None;
1322 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1323 {
1324 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1325 suffix_start = Some(range.end + 1);
1326 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1327 suffix_start = Some(range.end);
1328 }
1329 }
1330
1331 let suffix = if let Some(suffix_start) = suffix_start {
1332 MessageAnchor {
1333 id: MessageId(post_inc(&mut self.next_message_id.0)),
1334 start: self.buffer.read(cx).anchor_before(suffix_start),
1335 }
1336 } else {
1337 self.buffer.update(cx, |buffer, cx| {
1338 buffer.edit([(range.end..range.end, "\n")], None, cx);
1339 });
1340 edited_buffer = true;
1341 MessageAnchor {
1342 id: MessageId(post_inc(&mut self.next_message_id.0)),
1343 start: self.buffer.read(cx).anchor_before(range.end + 1),
1344 }
1345 };
1346
1347 self.message_anchors
1348 .insert(message.index_range.end + 1, suffix.clone());
1349 self.messages_metadata.insert(
1350 suffix.id,
1351 MessageMetadata {
1352 role,
1353 sent_at: Local::now(),
1354 status: MessageStatus::Done,
1355 },
1356 );
1357
1358 let new_messages =
1359 if range.start == range.end || range.start == message.offset_range.start {
1360 (None, Some(suffix))
1361 } else {
1362 let mut prefix_end = None;
1363 if range.start > message.offset_range.start
1364 && range.end < message.offset_range.end - 1
1365 {
1366 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1367 prefix_end = Some(range.start + 1);
1368 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1369 == Some('\n')
1370 {
1371 prefix_end = Some(range.start);
1372 }
1373 }
1374
1375 let selection = if let Some(prefix_end) = prefix_end {
1376 cx.emit(ConversationEvent::MessagesEdited);
1377 MessageAnchor {
1378 id: MessageId(post_inc(&mut self.next_message_id.0)),
1379 start: self.buffer.read(cx).anchor_before(prefix_end),
1380 }
1381 } else {
1382 self.buffer.update(cx, |buffer, cx| {
1383 buffer.edit([(range.start..range.start, "\n")], None, cx)
1384 });
1385 edited_buffer = true;
1386 MessageAnchor {
1387 id: MessageId(post_inc(&mut self.next_message_id.0)),
1388 start: self.buffer.read(cx).anchor_before(range.end + 1),
1389 }
1390 };
1391
1392 self.message_anchors
1393 .insert(message.index_range.end + 1, selection.clone());
1394 self.messages_metadata.insert(
1395 selection.id,
1396 MessageMetadata {
1397 role,
1398 sent_at: Local::now(),
1399 status: MessageStatus::Done,
1400 },
1401 );
1402 (Some(selection), Some(suffix))
1403 };
1404
1405 if !edited_buffer {
1406 cx.emit(ConversationEvent::MessagesEdited);
1407 }
1408 new_messages
1409 } else {
1410 (None, None)
1411 }
1412 }
1413
1414 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1415 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1416 let api_key = self.api_key.borrow().clone();
1417 if let Some(api_key) = api_key {
1418 let messages = self
1419 .messages(cx)
1420 .take(2)
1421 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1422 .chain(Some(RequestMessage {
1423 role: Role::User,
1424 content:
1425 "Summarize the conversation into a short title without punctuation"
1426 .into(),
1427 }));
1428 let request = OpenAIRequest {
1429 model: self.model.clone(),
1430 messages: messages.collect(),
1431 stream: true,
1432 };
1433
1434 let stream = stream_completion(api_key, cx.background().clone(), request);
1435 self.pending_summary = cx.spawn(|this, mut cx| {
1436 async move {
1437 let mut messages = stream.await?;
1438
1439 while let Some(message) = messages.next().await {
1440 let mut message = message?;
1441 if let Some(choice) = message.choices.pop() {
1442 let text = choice.delta.content.unwrap_or_default();
1443 this.update(&mut cx, |this, cx| {
1444 this.summary
1445 .get_or_insert(Default::default())
1446 .text
1447 .push_str(&text);
1448 cx.emit(ConversationEvent::SummaryChanged);
1449 });
1450 }
1451 }
1452
1453 this.update(&mut cx, |this, cx| {
1454 if let Some(summary) = this.summary.as_mut() {
1455 summary.done = true;
1456 cx.emit(ConversationEvent::SummaryChanged);
1457 }
1458 });
1459
1460 anyhow::Ok(())
1461 }
1462 .log_err()
1463 });
1464 }
1465 }
1466 }
1467
1468 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1469 self.messages_for_offsets([offset], cx).pop()
1470 }
1471
1472 fn messages_for_offsets(
1473 &self,
1474 offsets: impl IntoIterator<Item = usize>,
1475 cx: &AppContext,
1476 ) -> Vec<Message> {
1477 let mut result = Vec::new();
1478
1479 let mut messages = self.messages(cx).peekable();
1480 let mut offsets = offsets.into_iter().peekable();
1481 let mut current_message = messages.next();
1482 while let Some(offset) = offsets.next() {
1483 // Locate the message that contains the offset.
1484 while current_message.as_ref().map_or(false, |message| {
1485 !message.offset_range.contains(&offset) && messages.peek().is_some()
1486 }) {
1487 current_message = messages.next();
1488 }
1489 let Some(message) = current_message.as_ref() else { break };
1490
1491 // Skip offsets that are in the same message.
1492 while offsets.peek().map_or(false, |offset| {
1493 message.offset_range.contains(offset) || messages.peek().is_none()
1494 }) {
1495 offsets.next();
1496 }
1497
1498 result.push(message.clone());
1499 }
1500 result
1501 }
1502
1503 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1504 let buffer = self.buffer.read(cx);
1505 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1506 iter::from_fn(move || {
1507 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1508 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1509 let message_start = message_anchor.start.to_offset(buffer);
1510 let mut message_end = None;
1511 let mut end_ix = start_ix;
1512 while let Some((_, next_message)) = message_anchors.peek() {
1513 if next_message.start.is_valid(buffer) {
1514 message_end = Some(next_message.start);
1515 break;
1516 } else {
1517 end_ix += 1;
1518 message_anchors.next();
1519 }
1520 }
1521 let message_end = message_end
1522 .unwrap_or(language::Anchor::MAX)
1523 .to_offset(buffer);
1524 return Some(Message {
1525 index_range: start_ix..end_ix,
1526 offset_range: message_start..message_end,
1527 id: message_anchor.id,
1528 anchor: message_anchor.start,
1529 role: metadata.role,
1530 sent_at: metadata.sent_at,
1531 status: metadata.status.clone(),
1532 });
1533 }
1534 None
1535 })
1536 }
1537
1538 fn save(
1539 &mut self,
1540 debounce: Option<Duration>,
1541 fs: Arc<dyn Fs>,
1542 cx: &mut ModelContext<Conversation>,
1543 ) {
1544 self.pending_save = cx.spawn(|this, mut cx| async move {
1545 if let Some(debounce) = debounce {
1546 cx.background().timer(debounce).await;
1547 }
1548
1549 let (old_path, summary) = this.read_with(&cx, |this, _| {
1550 let path = this.path.clone();
1551 let summary = if let Some(summary) = this.summary.as_ref() {
1552 if summary.done {
1553 Some(summary.text.clone())
1554 } else {
1555 None
1556 }
1557 } else {
1558 None
1559 };
1560 (path, summary)
1561 });
1562
1563 if let Some(summary) = summary {
1564 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1565 let path = if let Some(old_path) = old_path {
1566 old_path
1567 } else {
1568 let mut discriminant = 1;
1569 let mut new_path;
1570 loop {
1571 new_path = CONVERSATIONS_DIR.join(&format!(
1572 "{} - {}.zed.json",
1573 summary.trim(),
1574 discriminant
1575 ));
1576 if fs.is_file(&new_path).await {
1577 discriminant += 1;
1578 } else {
1579 break;
1580 }
1581 }
1582 new_path
1583 };
1584
1585 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1586 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1587 .await?;
1588 this.update(&mut cx, |this, _| this.path = Some(path));
1589 }
1590
1591 Ok(())
1592 });
1593 }
1594}
1595
/// The streaming tasks started by one call to `Conversation::assist`.
struct PendingCompletion {
    /// Identifier allocated from the conversation's `completion_count`.
    id: usize,
    /// Held only to keep the tasks alive; dropping this struct drops them.
    _tasks: Vec<Task<()>>,
}
1600
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// The conversation's summary changed, so anything displaying the
    /// editor's title should refresh.
    TabContentChanged,
}
1604
/// A scroll position expressed relative to a cursor, so the viewport can be
/// kept steady relative to that cursor while text is streamed into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x` is the horizontal scroll position; `y` is the number of display
    /// rows between the top of the viewport and the cursor's row.
    offset_before_cursor: Vector2F,
    /// The cursor the position is measured from.
    cursor: Anchor,
}
1610
/// A view that edits one `Conversation` through an embedded `Editor`.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem used to persist the conversation.
    fs: Arc<dyn Fs>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to restore while completions stream
    /// in. `None` when the cursor is off screen or the user scrolled away.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1619
impl ConversationEditor {
    /// Creates an editor for a brand-new conversation.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Creates an editor for an existing `conversation`.
    ///
    /// The embedded editor soft-wraps at the editor width and hides the gutter
    /// and wrap guides. The view re-renders whenever the conversation notifies,
    /// subscribes to conversation and editor events, and starts with a header
    /// block above every message.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor.set_show_wrap_guides(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles `Assist`: requests a reply for every message containing a
    /// cursor, then moves the cursors to the start of any user messages the
    /// conversation queued up for the next reply.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
            // Avoid scrolling to the new cursor position so the assistant's output is stable.
            cx.defer(|this, _| this.scroll_position = None);
        }
    }

    /// Handles `editor::Cancel`: cancels the most recent pending completion,
    /// or lets the action propagate when there was nothing to cancel.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Handles `CycleMessageRole`: cycles the role of every message that
    /// contains a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Returns the offset of the head of every selection in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to events from the conversation model.
    ///
    /// - `MessagesEdited`: rebuilds the message headers and schedules a save
    ///   with a 500ms debounce.
    /// - `SummaryChanged`: asks for the tab title to refresh and saves
    ///   immediately.
    /// - `StreamedCompletion`: if a cursor-relative scroll position is
    ///   recorded, scrolls so the cursor stays on the same viewport row.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks whether the viewport should stay pinned relative to the cursor.
    ///
    /// Autoscrolls and selection changes record the current cursor-relative
    /// scroll position. A non-automatic scroll that changes the viewport
    /// relative to the cursor clears it.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Returns the scroll position relative to the newest cursor, or `None`
    /// when the cursor's display row is not within the visible rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces the editor's header blocks with one block above each message.
    ///
    /// Each header shows the sender's role (clicking it cycles the role) and
    /// the time the message was sent. For messages whose status is `Error`, it
    /// also shows an icon whose tooltip contains the error.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Handles `QuoteSelection` in a workspace.
    ///
    /// Does nothing if there is no assistant panel or no active editor.
    /// Otherwise it focuses the panel if needed, then inserts the newest
    /// selection of the active editor into the panel's active conversation,
    /// creating a conversation if there is none.
    ///
    /// Markdown is quoted line by line with `> `. Anything else is wrapped in a
    /// fenced code block tagged with the lowercased language name; the tag is
    /// empty unless the selection starts and ends in the same language. An
    /// empty selection inserts nothing.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy`.
    ///
    /// With a single selection spanning more than one message, the clipboard
    /// receives each message's selected text preceded by a `## {role}` heading.
    /// In every other case the action propagates (presumably to the editor's
    /// own copy handler).
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Handles `Split`: asks the conversation to split the message at each
    /// selection. The buffer snapshot is re-taken for every selection because
    /// each split may edit the buffer.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Handles `Save`: saves the conversation without a debounce.
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles the model: `gpt-4-0613` switches to `gpt-3.5-turbo-0613`, and
    /// any other model switches to `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// Returns the conversation's summary text (which may still be streaming),
    /// or "New Conversation" if there is no summary yet.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Renders the current model name as a clickable label that cycles the
    /// model when clicked.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Renders the number of tokens left in the context window.
    ///
    /// The label is styled as exhausted (<= 0), low (<= 500), or normal.
    /// Returns `None` until the token count is known.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else if remaining_tokens <= 500 {
            &style.low_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2090
/// `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2094
2095impl View for ConversationEditor {
2096 fn ui_name() -> &'static str {
2097 "ConversationEditor"
2098 }
2099
2100 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2101 let theme = &theme::current(cx).assistant;
2102 Stack::new()
2103 .with_child(
2104 ChildView::new(&self.editor, cx)
2105 .contained()
2106 .with_style(theme.container),
2107 )
2108 .with_child(
2109 Flex::row()
2110 .with_child(self.render_current_model(theme, cx))
2111 .with_children(self.render_remaining_tokens(theme, cx))
2112 .aligned()
2113 .top()
2114 .right(),
2115 )
2116 .into_any()
2117 }
2118
2119 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2120 if cx.is_self_focused() {
2121 cx.focus(&self.editor);
2122 }
2123 }
2124}
2125
/// Marks where a message begins in the conversation buffer.
///
/// The anchor becomes invalid if the text around it is deleted. Such anchors
/// are folded into the preceding message when messages are enumerated.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
2131
/// A resolved message: its anchor and metadata together with the buffer
/// offsets it covers at the time it was produced.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets from this message's start to the start of the next valid
    /// message (or the end of the buffer).
    offset_range: Range<usize>,
    /// Indices into the conversation's anchor list covered by this message.
    /// `end` is the index of the last anchor folded into it, so the slot after
    /// the message is `index_range.end + 1`.
    index_range: Range<usize>,
    id: MessageId,
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2142
2143impl Message {
2144 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2145 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2146 content.extend(buffer.text_for_range(self.offset_range.clone()));
2147 RequestMessage {
2148 role: self.role,
2149 content: content.trim_end().into(),
2150 }
2151 }
2152}
2153
2154async fn stream_completion(
2155 api_key: String,
2156 executor: Arc<Background>,
2157 mut request: OpenAIRequest,
2158) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2159 request.stream = true;
2160
2161 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2162
2163 let json_data = serde_json::to_string(&request)?;
2164 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2165 .header("Content-Type", "application/json")
2166 .header("Authorization", format!("Bearer {}", api_key))
2167 .body(json_data)?
2168 .send_async()
2169 .await?;
2170
2171 let status = response.status();
2172 if status == StatusCode::OK {
2173 executor
2174 .spawn(async move {
2175 let mut lines = BufReader::new(response.body_mut()).lines();
2176
2177 fn parse_line(
2178 line: Result<String, io::Error>,
2179 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2180 if let Some(data) = line?.strip_prefix("data: ") {
2181 let event = serde_json::from_str(&data)?;
2182 Ok(Some(event))
2183 } else {
2184 Ok(None)
2185 }
2186 }
2187
2188 while let Some(line) = lines.next().await {
2189 if let Some(event) = parse_line(line).transpose() {
2190 let done = event.as_ref().map_or(false, |event| {
2191 event
2192 .choices
2193 .last()
2194 .map_or(false, |choice| choice.finish_reason.is_some())
2195 });
2196 if tx.unbounded_send(event).is_err() {
2197 break;
2198 }
2199
2200 if done {
2201 break;
2202 }
2203 }
2204 }
2205
2206 anyhow::Ok(())
2207 })
2208 .detach();
2209
2210 Ok(rx)
2211 } else {
2212 let mut body = String::new();
2213 response.body_mut().read_to_string(&mut body).await?;
2214
2215 #[derive(Deserialize)]
2216 struct OpenAIResponse {
2217 error: OpenAIError,
2218 }
2219
2220 #[derive(Deserialize)]
2221 struct OpenAIError {
2222 message: String,
2223 }
2224
2225 match serde_json::from_str::<OpenAIResponse>(&body) {
2226 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2227 "Failed to connect to OpenAI API: {}",
2228 response.error.message,
2229 )),
2230
2231 _ => Err(anyhow!(
2232 "Failed to connect to OpenAI API: {} {}",
2233 response.status(),
2234 body,
2235 )),
2236 }
2237 }
2238}
2239
/// Unit tests for how `Conversation` keeps message boundaries in sync with its buffer.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Inserting messages, editing text at message boundaries, and undoing/redoing
    /// a deletion that spans several messages keep message ranges consistent.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts with one empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message 1 adds a one-character separator. Message 1 now
        // spans it, and the new message starts right after it.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text typed at the start of each message belongs to that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        // Appending after the last message.
        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 again places the new message between 2 and 3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// `split_message` with empty and non-empty ranges: when an existing newline is
    /// reused, when a new one is inserted, and which messages are returned.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Split at the end of "aaa"; only the second returned message is checked.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        // Splitting at the same offset again, now the end of message 1.
        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Offset 9 is right after "bbb\n" inside message 2: the existing newline is
        // reused and the text is unchanged.
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Offset 9 is now the start of message 4, so a new newline is inserted.
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range ("dd"): the selected text becomes its own
        // message (message 6) and the text after it becomes another (message 7).
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// `messages_for_offsets` maps each offset to the message containing it. It
    /// reports a message once even when several offsets fall inside it, and it
    /// includes an offset at the very end of the buffer.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build "aaa\nbbb\nccc" as three messages; each insertion adds the "\n".
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 are both in message 1, which is reported once. Offset 11
        // (end of buffer) resolves to the last message.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message is still found at the buffer's end offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Ids of the messages `messages_for_offsets` returns for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Serializing a conversation and deserializing it reproduces the buffer text
    /// and every message's id, role and range.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Finalizing ends the transaction, so the `undo` below reverts only the
        // message_3 insertion and not this text edit.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // After the undo, message_3 no longer appears in the conversation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Returns `(id, role, offset_range)` for every message, in the order
    /// `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}