1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API.
///
/// Also used as the key under which the user's API key is written to, read
/// from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace used by the assistant panel and its
// conversation editors. Most handlers are registered in `init` below.
// NOTE(review): no handler taking `&NewConversation` is visible in `init` (the
// panel handles `workspace::NewFile` instead); confirm it is bound elsewhere.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`].
///
/// The `Panel` impl maps each variant to a dock request: `ZoomIn`/`ZoomOut`
/// via `should_zoom_in_on_event`/`should_zoom_out_on_event`, `Focus` via
/// `is_focus_event`, `Close` via `should_close_on_event`, and
/// `DockPositionChanged` via `should_change_position_on_event`.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    ZoomIn,
    ZoomOut,
    Focus,
    Close,
    /// Emitted by the settings observer installed in `load` when the configured
    /// dock position changes.
    DockPositionChanged,
}
102
/// The dockable assistant panel: a list of open conversation editors plus a
/// browser of conversations saved on disk.
pub struct AssistantPanel {
    // Workspace that owns this panel (used by the quote-selection button).
    workspace: WeakViewHandle<Workspace>,
    // User-chosen size for left/right (width) or bottom (height) docking;
    // `None` falls back to the defaults in `AssistantSettings`.
    width: Option<f32>,
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the
    // saved-conversation list instead.
    active_editor_index: Option<usize>,
    // Previous value of `active_editor_index`, restored by the history button.
    prev_active_editor_index: Option<usize>,
    // Conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Metadata of conversations on disk, refreshed by the directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    // Toolbar hosting the buffer search bar for the active conversation.
    toolbar: ViewHandle<Toolbar>,
    // OpenAI API key shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    // Present while the panel is prompting the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Set once the env var / credential store has been consulted, so that
    // lookup happens at most once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    // Background task that re-lists saved conversations on directory changes;
    // held for the panel's lifetime.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
    /// Asynchronously builds the assistant panel for `workspace`.
    ///
    /// Lists the saved conversations on disk (a listing error is logged and
    /// treated as "none"), then creates the panel view with:
    /// - a background task that re-lists saved conversations whenever the
    ///   watcher on `CONVERSATIONS_DIR` yields an event,
    /// - a toolbar hosting a `BufferSearchBar`,
    /// - a `SettingsStore` observer that emits `DockPositionChanged` when the
    ///   configured dock position differs from the last one seen.
    ///
    /// Returns an error if the workspace cannot be read or updated
    /// (presumably because it has been released).
    pub fn load(
        workspace: WeakViewHandle<Workspace>,
        cx: AsyncAppContext,
    ) -> Task<Result<ViewHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
            // Listing failures are logged and the panel starts with no saved conversations.
            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                .await
                .log_err()
                .unwrap_or_default();

            // TODO: deserialize state.
            let workspace_handle = workspace.clone();
            workspace.update(&mut cx, |workspace, cx| {
                cx.add_view::<Self, _>(|cx| {
                    // Passed as the second argument to `Fs::watch`
                    // (presumably the event latency/debounce — confirm).
                    const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
                    let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
                        let mut events = fs
                            .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
                            .await;
                        // Re-list the directory after every batch of file-system events.
                        while events.next().await.is_some() {
                            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                                .await
                                .log_err()
                                .unwrap_or_default();
                            // An update error (presumably the panel was released) is ignored.
                            this.update(&mut cx, |this, cx| {
                                this.saved_conversations = saved_conversations;
                                cx.notify();
                            })
                            .ok();
                        }

                        anyhow::Ok(())
                    });

                    let toolbar = cx.add_view(|cx| {
                        let mut toolbar = Toolbar::new();
                        toolbar.set_can_navigate(false, cx);
                        toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
                        toolbar
                    });
                    let mut this = Self {
                        workspace: workspace_handle,
                        active_editor_index: Default::default(),
                        prev_active_editor_index: Default::default(),
                        editors: Default::default(),
                        saved_conversations,
                        saved_conversations_list_state: Default::default(),
                        zoomed: false,
                        has_focus: false,
                        toolbar,
                        api_key: Rc::new(RefCell::new(None)),
                        api_key_editor: None,
                        has_read_credentials: false,
                        languages: workspace.app_state().languages.clone(),
                        fs: workspace.app_state().fs.clone(),
                        width: None,
                        height: None,
                        subscriptions: Default::default(),
                        _watch_saved_conversations,
                    };

                    // Track the dock position from settings so a change can be reported
                    // to the dock via `DockPositionChanged`.
                    let mut old_dock_position = this.position(cx);
                    this.subscriptions =
                        vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
                            let new_dock_position = this.position(cx);
                            if new_dock_position != old_dock_position {
                                old_dock_position = new_dock_position;
                                cx.emit(AssistantPanelEvent::DockPositionChanged);
                            }
                            cx.notify();
                        })];

                    this
                })
            })
        })
    }
203
204 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
205 let editor = cx.add_view(|cx| {
206 ConversationEditor::new(
207 self.api_key.clone(),
208 self.languages.clone(),
209 self.fs.clone(),
210 cx,
211 )
212 });
213 self.add_conversation(editor.clone(), cx);
214 editor
215 }
216
217 fn add_conversation(
218 &mut self,
219 editor: ViewHandle<ConversationEditor>,
220 cx: &mut ViewContext<Self>,
221 ) {
222 self.subscriptions
223 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
224
225 let conversation = editor.read(cx).conversation.clone();
226 self.subscriptions
227 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
228
229 let index = self.editors.len();
230 self.editors.push(editor);
231 self.set_active_editor_index(Some(index), cx);
232 }
233
234 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
235 self.prev_active_editor_index = self.active_editor_index;
236 self.active_editor_index = index;
237 if let Some(editor) = self.active_editor() {
238 let editor = editor.read(cx).editor.clone();
239 self.toolbar.update(cx, |toolbar, cx| {
240 toolbar.set_active_item(Some(&editor), cx);
241 });
242 if self.has_focus(cx) {
243 cx.focus(&editor);
244 }
245 } else {
246 self.toolbar.update(cx, |toolbar, cx| {
247 toolbar.set_active_item(None, cx);
248 });
249 }
250
251 cx.notify();
252 }
253
254 fn handle_conversation_editor_event(
255 &mut self,
256 _: ViewHandle<ConversationEditor>,
257 event: &ConversationEditorEvent,
258 cx: &mut ViewContext<Self>,
259 ) {
260 match event {
261 ConversationEditorEvent::TabContentChanged => cx.notify(),
262 }
263 }
264
265 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
266 if let Some(api_key) = self
267 .api_key_editor
268 .as_ref()
269 .map(|editor| editor.read(cx).text(cx))
270 {
271 if !api_key.is_empty() {
272 cx.platform()
273 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
274 .log_err();
275 *self.api_key.borrow_mut() = Some(api_key);
276 self.api_key_editor.take();
277 cx.focus_self();
278 cx.notify();
279 }
280 } else {
281 cx.propagate_action();
282 }
283 }
284
285 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
286 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
287 self.api_key.take();
288 self.api_key_editor = Some(build_api_key_editor(cx));
289 cx.focus_self();
290 cx.notify();
291 }
292
293 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
294 if self.zoomed {
295 cx.emit(AssistantPanelEvent::ZoomOut)
296 } else {
297 cx.emit(AssistantPanelEvent::ZoomIn)
298 }
299 }
300
301 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
302 let mut propagate_action = true;
303 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
304 search_bar.update(cx, |search_bar, cx| {
305 if search_bar.show(cx) {
306 search_bar.search_suggested(cx);
307 if action.focus {
308 search_bar.select_query(cx);
309 cx.focus_self();
310 }
311 propagate_action = false
312 }
313 });
314 }
315 if propagate_action {
316 cx.propagate_action();
317 }
318 }
319
320 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
321 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
322 if !search_bar.read(cx).is_dismissed() {
323 search_bar.update(cx, |search_bar, cx| {
324 search_bar.dismiss(&Default::default(), cx)
325 });
326 return;
327 }
328 }
329 cx.propagate_action();
330 }
331
332 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
335 }
336 }
337
338 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
339 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
340 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
341 }
342 }
343
344 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
345 self.editors.get(self.active_editor_index?)
346 }
347
348 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
349 enum History {}
350 let theme = theme::current(cx);
351 let tooltip_style = theme::current(cx).tooltip.clone();
352 MouseEventHandler::new::<History, _>(0, cx, |state, _| {
353 let style = theme.assistant.hamburger_button.style_for(state);
354 Svg::for_style(style.icon.clone())
355 .contained()
356 .with_style(style.container)
357 })
358 .with_cursor_style(CursorStyle::PointingHand)
359 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
360 if this.active_editor().is_some() {
361 this.set_active_editor_index(None, cx);
362 } else {
363 this.set_active_editor_index(this.prev_active_editor_index, cx);
364 }
365 })
366 .with_tooltip::<History>(1, "History", None, tooltip_style, cx)
367 }
368
369 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
370 if self.active_editor().is_some() {
371 vec![
372 Self::render_split_button(cx).into_any(),
373 Self::render_quote_button(cx).into_any(),
374 Self::render_assist_button(cx).into_any(),
375 ]
376 } else {
377 Default::default()
378 }
379 }
380
381 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
382 let theme = theme::current(cx);
383 let tooltip_style = theme::current(cx).tooltip.clone();
384 MouseEventHandler::new::<Split, _>(0, cx, |state, _| {
385 let style = theme.assistant.split_button.style_for(state);
386 Svg::for_style(style.icon.clone())
387 .contained()
388 .with_style(style.container)
389 })
390 .with_cursor_style(CursorStyle::PointingHand)
391 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
392 if let Some(active_editor) = this.active_editor() {
393 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
394 }
395 })
396 .with_tooltip::<Split>(
397 1,
398 "Split Message",
399 Some(Box::new(Split)),
400 tooltip_style,
401 cx,
402 )
403 }
404
405 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
406 let theme = theme::current(cx);
407 let tooltip_style = theme::current(cx).tooltip.clone();
408 MouseEventHandler::new::<Assist, _>(0, cx, |state, _| {
409 let style = theme.assistant.assist_button.style_for(state);
410 Svg::for_style(style.icon.clone())
411 .contained()
412 .with_style(style.container)
413 })
414 .with_cursor_style(CursorStyle::PointingHand)
415 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
416 if let Some(active_editor) = this.active_editor() {
417 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
418 }
419 })
420 .with_tooltip::<Assist>(1, "Assist", Some(Box::new(Assist)), tooltip_style, cx)
421 }
422
423 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
424 let theme = theme::current(cx);
425 let tooltip_style = theme::current(cx).tooltip.clone();
426 MouseEventHandler::new::<QuoteSelection, _>(0, cx, |state, _| {
427 let style = theme.assistant.quote_button.style_for(state);
428 Svg::for_style(style.icon.clone())
429 .contained()
430 .with_style(style.container)
431 })
432 .with_cursor_style(CursorStyle::PointingHand)
433 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
434 if let Some(workspace) = this.workspace.upgrade(cx) {
435 cx.window_context().defer(move |cx| {
436 workspace.update(cx, |workspace, cx| {
437 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
438 });
439 });
440 }
441 })
442 .with_tooltip::<QuoteSelection>(
443 1,
444 "Quote Selection",
445 Some(Box::new(QuoteSelection)),
446 tooltip_style,
447 cx,
448 )
449 }
450
451 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
452 let theme = theme::current(cx);
453 let tooltip_style = theme::current(cx).tooltip.clone();
454 MouseEventHandler::new::<NewConversation, _>(0, cx, |state, _| {
455 let style = theme.assistant.plus_button.style_for(state);
456 Svg::for_style(style.icon.clone())
457 .contained()
458 .with_style(style.container)
459 })
460 .with_cursor_style(CursorStyle::PointingHand)
461 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
462 this.new_conversation(cx);
463 })
464 .with_tooltip::<NewConversation>(
465 1,
466 "New Conversation",
467 Some(Box::new(NewConversation)),
468 tooltip_style,
469 cx,
470 )
471 }
472
473 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
474 enum ToggleZoomButton {}
475
476 let theme = theme::current(cx);
477 let tooltip_style = theme::current(cx).tooltip.clone();
478 let style = if self.zoomed {
479 &theme.assistant.zoom_out_button
480 } else {
481 &theme.assistant.zoom_in_button
482 };
483
484 MouseEventHandler::new::<ToggleZoomButton, _>(0, cx, |state, _| {
485 let style = style.style_for(state);
486 Svg::for_style(style.icon.clone())
487 .contained()
488 .with_style(style.container)
489 })
490 .with_cursor_style(CursorStyle::PointingHand)
491 .on_click(MouseButton::Left, |_, this, cx| {
492 this.toggle_zoom(&ToggleZoom, cx);
493 })
494 .with_tooltip::<ToggleZoom>(
495 0,
496 if self.zoomed { "Zoom Out" } else { "Zoom In" },
497 Some(Box::new(ToggleZoom)),
498 tooltip_style,
499 cx,
500 )
501 }
502
503 fn render_saved_conversation(
504 &mut self,
505 index: usize,
506 cx: &mut ViewContext<Self>,
507 ) -> impl Element<Self> {
508 let conversation = &self.saved_conversations[index];
509 let path = conversation.path.clone();
510 MouseEventHandler::new::<SavedConversationMetadata, _>(index, cx, move |state, cx| {
511 let style = &theme::current(cx).assistant.saved_conversation;
512 Flex::row()
513 .with_child(
514 Label::new(
515 conversation.mtime.format("%F %I:%M%p").to_string(),
516 style.saved_at.text.clone(),
517 )
518 .aligned()
519 .contained()
520 .with_style(style.saved_at.container),
521 )
522 .with_child(
523 Label::new(conversation.title.clone(), style.title.text.clone())
524 .aligned()
525 .contained()
526 .with_style(style.title.container),
527 )
528 .contained()
529 .with_style(*style.container.style_for(state))
530 })
531 .with_cursor_style(CursorStyle::PointingHand)
532 .on_click(MouseButton::Left, move |_, this, cx| {
533 this.open_conversation(path.clone(), cx)
534 .detach_and_log_err(cx)
535 })
536 }
537
538 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
539 if let Some(ix) = self.editor_index_for_path(&path, cx) {
540 self.set_active_editor_index(Some(ix), cx);
541 return Task::ready(Ok(()));
542 }
543
544 let fs = self.fs.clone();
545 let api_key = self.api_key.clone();
546 let languages = self.languages.clone();
547 cx.spawn(|this, mut cx| async move {
548 let saved_conversation = fs.load(&path).await?;
549 let saved_conversation = serde_json::from_str(&saved_conversation)?;
550 let conversation = cx.add_model(|cx| {
551 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
552 });
553 this.update(&mut cx, |this, cx| {
554 // If, by the time we've loaded the conversation, the user has already opened
555 // the same conversation, we don't want to open it again.
556 if let Some(ix) = this.editor_index_for_path(&path, cx) {
557 this.set_active_editor_index(Some(ix), cx);
558 } else {
559 let editor = cx
560 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
561 this.add_conversation(editor, cx);
562 }
563 })?;
564 Ok(())
565 })
566 }
567
568 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
569 self.editors
570 .iter()
571 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
572 }
573}
574
575fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
576 cx.add_view(|cx| {
577 let mut editor = Editor::single_line(
578 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
579 cx,
580 );
581 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
582 editor
583 })
584}
585
// The panel emits `AssistantPanelEvent`s; the `Panel` impl below decides how
// the dock reacts to each of them.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
589
590impl View for AssistantPanel {
591 fn ui_name() -> &'static str {
592 "AssistantPanel"
593 }
594
595 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
596 let theme = &theme::current(cx);
597 let style = &theme.assistant;
598 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
599 Flex::column()
600 .with_child(
601 Text::new(
602 "Paste your OpenAI API key and press Enter to use the assistant",
603 style.api_key_prompt.text.clone(),
604 )
605 .aligned(),
606 )
607 .with_child(
608 ChildView::new(api_key_editor, cx)
609 .contained()
610 .with_style(style.api_key_editor.container)
611 .aligned(),
612 )
613 .contained()
614 .with_style(style.api_key_prompt.container)
615 .aligned()
616 .into_any()
617 } else {
618 let title = self.active_editor().map(|editor| {
619 Label::new(editor.read(cx).title(cx), style.title.text.clone())
620 .contained()
621 .with_style(style.title.container)
622 .aligned()
623 .left()
624 .flex(1., false)
625 });
626 let mut header = Flex::row()
627 .with_child(Self::render_hamburger_button(cx).aligned())
628 .with_children(title);
629 if self.has_focus {
630 header.add_children(
631 self.render_editor_tools(cx)
632 .into_iter()
633 .map(|tool| tool.aligned().flex_float()),
634 );
635 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
636 header.add_child(self.render_zoom_button(cx).aligned());
637 }
638
639 Flex::column()
640 .with_child(
641 header
642 .contained()
643 .with_style(theme.workspace.tab_bar.container)
644 .expanded()
645 .constrained()
646 .with_height(theme.workspace.tab_bar.height),
647 )
648 .with_children(if self.toolbar.read(cx).hidden() {
649 None
650 } else {
651 Some(ChildView::new(&self.toolbar, cx).expanded())
652 })
653 .with_child(if let Some(editor) = self.active_editor() {
654 ChildView::new(editor, cx).flex(1., true).into_any()
655 } else {
656 UniformList::new(
657 self.saved_conversations_list_state.clone(),
658 self.saved_conversations.len(),
659 cx,
660 |this, range, items, cx| {
661 for ix in range {
662 items.push(this.render_saved_conversation(ix, cx).into_any());
663 }
664 },
665 )
666 .flex(1., true)
667 .into_any()
668 })
669 .into_any()
670 }
671 }
672
673 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
674 self.has_focus = true;
675 self.toolbar
676 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
677 cx.notify();
678 if cx.is_self_focused() {
679 if let Some(editor) = self.active_editor() {
680 cx.focus(editor);
681 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
682 cx.focus(api_key_editor);
683 }
684 }
685 }
686
687 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
688 self.has_focus = false;
689 self.toolbar
690 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
691 cx.notify();
692 }
693}
694
695impl Panel for AssistantPanel {
696 fn position(&self, cx: &WindowContext) -> DockPosition {
697 match settings::get::<AssistantSettings>(cx).dock {
698 AssistantDockPosition::Left => DockPosition::Left,
699 AssistantDockPosition::Bottom => DockPosition::Bottom,
700 AssistantDockPosition::Right => DockPosition::Right,
701 }
702 }
703
704 fn position_is_valid(&self, _: DockPosition) -> bool {
705 true
706 }
707
708 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
709 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
710 let dock = match position {
711 DockPosition::Left => AssistantDockPosition::Left,
712 DockPosition::Bottom => AssistantDockPosition::Bottom,
713 DockPosition::Right => AssistantDockPosition::Right,
714 };
715 settings.dock = Some(dock);
716 });
717 }
718
719 fn size(&self, cx: &WindowContext) -> f32 {
720 let settings = settings::get::<AssistantSettings>(cx);
721 match self.position(cx) {
722 DockPosition::Left | DockPosition::Right => {
723 self.width.unwrap_or_else(|| settings.default_width)
724 }
725 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
726 }
727 }
728
729 fn set_size(&mut self, size: Option<f32>, cx: &mut ViewContext<Self>) {
730 match self.position(cx) {
731 DockPosition::Left | DockPosition::Right => self.width = size,
732 DockPosition::Bottom => self.height = size,
733 }
734 cx.notify();
735 }
736
737 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
738 matches!(event, AssistantPanelEvent::ZoomIn)
739 }
740
741 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
742 matches!(event, AssistantPanelEvent::ZoomOut)
743 }
744
745 fn is_zoomed(&self, _: &WindowContext) -> bool {
746 self.zoomed
747 }
748
749 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
750 self.zoomed = zoomed;
751 cx.notify();
752 }
753
754 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
755 if active {
756 if self.api_key.borrow().is_none() && !self.has_read_credentials {
757 self.has_read_credentials = true;
758 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
759 Some(api_key)
760 } else if let Some((_, api_key)) = cx
761 .platform()
762 .read_credentials(OPENAI_API_URL)
763 .log_err()
764 .flatten()
765 {
766 String::from_utf8(api_key).log_err()
767 } else {
768 None
769 };
770 if let Some(api_key) = api_key {
771 *self.api_key.borrow_mut() = Some(api_key);
772 } else if self.api_key_editor.is_none() {
773 self.api_key_editor = Some(build_api_key_editor(cx));
774 cx.notify();
775 }
776 }
777
778 if self.editors.is_empty() {
779 self.new_conversation(cx);
780 }
781 }
782 }
783
784 fn icon_path(&self, cx: &WindowContext) -> Option<&'static str> {
785 settings::get::<AssistantSettings>(cx)
786 .button
787 .then(|| "icons/ai.svg")
788 }
789
790 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
791 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
792 }
793
794 fn should_change_position_on_event(event: &Self::Event) -> bool {
795 matches!(event, AssistantPanelEvent::DockPositionChanged)
796 }
797
798 fn should_activate_on_event(_: &Self::Event) -> bool {
799 false
800 }
801
802 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
803 matches!(event, AssistantPanelEvent::Close)
804 }
805
806 fn has_focus(&self, _: &WindowContext) -> bool {
807 self.has_focus
808 }
809
810 fn is_focus_event(event: &Self::Event) -> bool {
811 matches!(event, AssistantPanelEvent::Focus)
812 }
813}
814
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// The conversation's buffer was edited (emitted by `handle_buffer_event`).
    MessagesEdited,
    /// NOTE(review): presumably emitted when the summary text changes; the
    /// emitter is not in the visible code.
    SummaryChanged,
    /// NOTE(review): presumably emitted as a streamed completion delivers text;
    /// the emitter is not in the visible code.
    StreamedCompletion,
}
820
/// A conversation's summary; its `text` is what `serialize` writes as
/// `SavedConversation::summary`.
#[derive(Default)]
struct Summary {
    text: String,
    // `true` for summaries restored by `deserialize`; presumably `false` while a
    // summary is still being generated — confirm against the summarizer.
    done: bool,
}
826
/// Model for a single assistant conversation: one Markdown buffer holding all
/// message text, plus per-message anchors and metadata.
struct Conversation {
    // Buffer containing the text of every message.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message in `buffer`, paired with its id.
    message_anchors: Vec<MessageAnchor>,
    // Role, send time and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id to assign to the next message created.
    next_message_id: MessageId,
    // `None` until a summary exists; restored as done when deserialized.
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    // NOTE(review): completion bookkeeping; only initialized in the visible code.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    // OpenAI model used for requests and token counting.
    model: OpenAIModel,
    // Tokens used by the messages; `None` until the first count finishes.
    token_count: Option<usize>,
    // Context size of `model`.
    max_token_count: usize,
    // Most recent token-counting task (replaced on each recount).
    pending_token_count: Task<Option<()>>,
    // API key shared with the owning panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    // File this conversation was loaded from; `None` for new conversations.
    path: Option<PathBuf>,
    // Subscription to `buffer` events.
    _subscriptions: Vec<Subscription>,
}
845
// Conversations emit `ConversationEvent`s (e.g. `MessagesEdited` on buffer edits).
impl Entity for Conversation {
    type Event = ConversationEvent;
}
849
850impl Conversation {
851 fn new(
852 api_key: Rc<RefCell<Option<String>>>,
853 language_registry: Arc<LanguageRegistry>,
854 cx: &mut ModelContext<Self>,
855 ) -> Self {
856 let markdown = language_registry.language_for_name("Markdown");
857 let buffer = cx.add_model(|cx| {
858 let mut buffer = Buffer::new(0, "", cx);
859 buffer.set_language_registry(language_registry);
860 cx.spawn_weak(|buffer, mut cx| async move {
861 let markdown = markdown.await?;
862 let buffer = buffer
863 .upgrade(&cx)
864 .ok_or_else(|| anyhow!("buffer was dropped"))?;
865 buffer.update(&mut cx, |buffer, cx| {
866 buffer.set_language(Some(markdown), cx)
867 });
868 anyhow::Ok(())
869 })
870 .detach_and_log_err(cx);
871 buffer
872 });
873
874 let settings = settings::get::<AssistantSettings>(cx);
875 let model = settings.default_open_ai_model.clone();
876
877 let mut this = Self {
878 message_anchors: Default::default(),
879 messages_metadata: Default::default(),
880 next_message_id: Default::default(),
881 summary: None,
882 pending_summary: Task::ready(None),
883 completion_count: Default::default(),
884 pending_completions: Default::default(),
885 token_count: None,
886 max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
887 pending_token_count: Task::ready(None),
888 model: model.clone(),
889 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
890 pending_save: Task::ready(Ok(())),
891 path: None,
892 api_key,
893 buffer,
894 };
895 let message = MessageAnchor {
896 id: MessageId(post_inc(&mut this.next_message_id.0)),
897 start: language::Anchor::MIN,
898 };
899 this.message_anchors.push(message.clone());
900 this.messages_metadata.insert(
901 message.id,
902 MessageMetadata {
903 role: Role::User,
904 sent_at: Local::now(),
905 status: MessageStatus::Done,
906 },
907 );
908
909 this.count_remaining_tokens(cx);
910 this
911 }
912
913 fn serialize(&self, cx: &AppContext) -> SavedConversation {
914 SavedConversation {
915 zed: "conversation".into(),
916 version: SavedConversation::VERSION.into(),
917 text: self.buffer.read(cx).text(),
918 message_metadata: self.messages_metadata.clone(),
919 messages: self
920 .messages(cx)
921 .map(|message| SavedMessage {
922 id: message.id,
923 start: message.offset_range.start,
924 })
925 .collect(),
926 summary: self
927 .summary
928 .as_ref()
929 .map(|summary| summary.text.clone())
930 .unwrap_or_default(),
931 model: self.model.clone(),
932 }
933 }
934
935 fn deserialize(
936 saved_conversation: SavedConversation,
937 path: PathBuf,
938 api_key: Rc<RefCell<Option<String>>>,
939 language_registry: Arc<LanguageRegistry>,
940 cx: &mut ModelContext<Self>,
941 ) -> Self {
942 let model = saved_conversation.model;
943 let markdown = language_registry.language_for_name("Markdown");
944 let mut message_anchors = Vec::new();
945 let mut next_message_id = MessageId(0);
946 let buffer = cx.add_model(|cx| {
947 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
948 for message in saved_conversation.messages {
949 message_anchors.push(MessageAnchor {
950 id: message.id,
951 start: buffer.anchor_before(message.start),
952 });
953 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
954 }
955 buffer.set_language_registry(language_registry);
956 cx.spawn_weak(|buffer, mut cx| async move {
957 let markdown = markdown.await?;
958 let buffer = buffer
959 .upgrade(&cx)
960 .ok_or_else(|| anyhow!("buffer was dropped"))?;
961 buffer.update(&mut cx, |buffer, cx| {
962 buffer.set_language(Some(markdown), cx)
963 });
964 anyhow::Ok(())
965 })
966 .detach_and_log_err(cx);
967 buffer
968 });
969
970 let mut this = Self {
971 message_anchors,
972 messages_metadata: saved_conversation.message_metadata,
973 next_message_id,
974 summary: Some(Summary {
975 text: saved_conversation.summary,
976 done: true,
977 }),
978 pending_summary: Task::ready(None),
979 completion_count: Default::default(),
980 pending_completions: Default::default(),
981 token_count: None,
982 max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
983 pending_token_count: Task::ready(None),
984 model,
985 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
986 pending_save: Task::ready(Ok(())),
987 path: Some(path),
988 api_key,
989 buffer,
990 };
991 this.count_remaining_tokens(cx);
992 this
993 }
994
995 fn handle_buffer_event(
996 &mut self,
997 _: ModelHandle<Buffer>,
998 event: &language::Event,
999 cx: &mut ModelContext<Self>,
1000 ) {
1001 match event {
1002 language::Event::Edited => {
1003 self.count_remaining_tokens(cx);
1004 cx.emit(ConversationEvent::MessagesEdited);
1005 }
1006 _ => {}
1007 }
1008 }
1009
1010 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1011 let messages = self
1012 .messages(cx)
1013 .into_iter()
1014 .filter_map(|message| {
1015 Some(tiktoken_rs::ChatCompletionRequestMessage {
1016 role: match message.role {
1017 Role::User => "user".into(),
1018 Role::Assistant => "assistant".into(),
1019 Role::System => "system".into(),
1020 },
1021 content: self
1022 .buffer
1023 .read(cx)
1024 .text_for_range(message.offset_range)
1025 .collect(),
1026 name: None,
1027 })
1028 })
1029 .collect::<Vec<_>>();
1030 let model = self.model.clone();
1031 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1032 async move {
1033 cx.background().timer(Duration::from_millis(200)).await;
1034 let token_count = cx
1035 .background()
1036 .spawn(async move {
1037 tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
1038 })
1039 .await?;
1040
1041 this.upgrade(&cx)
1042 .ok_or_else(|| anyhow!("conversation was dropped"))?
1043 .update(&mut cx, |this, cx| {
1044 this.max_token_count =
1045 tiktoken_rs::model::get_context_size(&this.model.full_name());
1046 this.token_count = Some(token_count);
1047 cx.notify()
1048 });
1049 anyhow::Ok(())
1050 }
1051 .log_err()
1052 });
1053 }
1054
1055 fn remaining_tokens(&self) -> Option<isize> {
1056 Some(self.max_token_count as isize - self.token_count? as isize)
1057 }
1058
1059 fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
1060 self.model = model;
1061 self.count_remaining_tokens(cx);
1062 cx.notify();
1063 }
1064
1065 fn assist(
1066 &mut self,
1067 selected_messages: HashSet<MessageId>,
1068 cx: &mut ModelContext<Self>,
1069 ) -> Vec<MessageAnchor> {
1070 let mut user_messages = Vec::new();
1071 let mut tasks = Vec::new();
1072
1073 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1074 message
1075 .start
1076 .is_valid(self.buffer.read(cx))
1077 .then_some(message.id)
1078 });
1079
1080 for selected_message_id in selected_messages {
1081 let selected_message_role =
1082 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1083 metadata.role
1084 } else {
1085 continue;
1086 };
1087
1088 if selected_message_role == Role::Assistant {
1089 if let Some(user_message) = self.insert_message_after(
1090 selected_message_id,
1091 Role::User,
1092 MessageStatus::Done,
1093 cx,
1094 ) {
1095 user_messages.push(user_message);
1096 } else {
1097 continue;
1098 }
1099 } else {
1100 let request = OpenAIRequest {
1101 model: self.model.full_name().to_string(),
1102 messages: self
1103 .messages(cx)
1104 .filter(|message| matches!(message.status, MessageStatus::Done))
1105 .flat_map(|message| {
1106 let mut system_message = None;
1107 if message.id == selected_message_id {
1108 system_message = Some(RequestMessage {
1109 role: Role::System,
1110 content: concat!(
1111 "Treat the following messages as additional knowledge you have learned about, ",
1112 "but act as if they were not part of this conversation. That is, treat them ",
1113 "as if the user didn't see them and couldn't possibly inquire about them."
1114 ).into()
1115 });
1116 }
1117
1118 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1119 })
1120 .chain(Some(RequestMessage {
1121 role: Role::System,
1122 content: format!(
1123 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1124 selected_message_id.0
1125 ),
1126 }))
1127 .collect(),
1128 stream: true,
1129 };
1130
1131 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1132 let stream = stream_completion(api_key, cx.background().clone(), request);
1133 let assistant_message = self
1134 .insert_message_after(
1135 selected_message_id,
1136 Role::Assistant,
1137 MessageStatus::Pending,
1138 cx,
1139 )
1140 .unwrap();
1141
1142 // Queue up the user's next reply
1143 if Some(selected_message_id) == last_message_id {
1144 let user_message = self
1145 .insert_message_after(
1146 assistant_message.id,
1147 Role::User,
1148 MessageStatus::Done,
1149 cx,
1150 )
1151 .unwrap();
1152 user_messages.push(user_message);
1153 }
1154
1155 tasks.push(cx.spawn_weak({
1156 |this, mut cx| async move {
1157 let assistant_message_id = assistant_message.id;
1158 let stream_completion = async {
1159 let mut messages = stream.await?;
1160
1161 while let Some(message) = messages.next().await {
1162 let mut message = message?;
1163 if let Some(choice) = message.choices.pop() {
1164 this.upgrade(&cx)
1165 .ok_or_else(|| anyhow!("conversation was dropped"))?
1166 .update(&mut cx, |this, cx| {
1167 let text: Arc<str> = choice.delta.content?.into();
1168 let message_ix = this.message_anchors.iter().position(
1169 |message| message.id == assistant_message_id,
1170 )?;
1171 this.buffer.update(cx, |buffer, cx| {
1172 let offset = this.message_anchors[message_ix + 1..]
1173 .iter()
1174 .find(|message| message.start.is_valid(buffer))
1175 .map_or(buffer.len(), |message| {
1176 message
1177 .start
1178 .to_offset(buffer)
1179 .saturating_sub(1)
1180 });
1181 buffer.edit([(offset..offset, text)], None, cx);
1182 });
1183 cx.emit(ConversationEvent::StreamedCompletion);
1184
1185 Some(())
1186 });
1187 }
1188 smol::future::yield_now().await;
1189 }
1190
1191 this.upgrade(&cx)
1192 .ok_or_else(|| anyhow!("conversation was dropped"))?
1193 .update(&mut cx, |this, cx| {
1194 this.pending_completions.retain(|completion| {
1195 completion.id != this.completion_count
1196 });
1197 this.summarize(cx);
1198 });
1199
1200 anyhow::Ok(())
1201 };
1202
1203 let result = stream_completion.await;
1204 if let Some(this) = this.upgrade(&cx) {
1205 this.update(&mut cx, |this, cx| {
1206 if let Some(metadata) =
1207 this.messages_metadata.get_mut(&assistant_message.id)
1208 {
1209 match result {
1210 Ok(_) => {
1211 metadata.status = MessageStatus::Done;
1212 }
1213 Err(error) => {
1214 metadata.status = MessageStatus::Error(
1215 error.to_string().trim().into(),
1216 );
1217 }
1218 }
1219 cx.notify();
1220 }
1221 });
1222 }
1223 }
1224 }));
1225 }
1226 }
1227
1228 if !tasks.is_empty() {
1229 self.pending_completions.push(PendingCompletion {
1230 id: post_inc(&mut self.completion_count),
1231 _tasks: tasks,
1232 });
1233 }
1234
1235 user_messages
1236 }
1237
1238 fn cancel_last_assist(&mut self) -> bool {
1239 self.pending_completions.pop().is_some()
1240 }
1241
1242 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1243 for id in ids {
1244 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1245 metadata.role.cycle();
1246 cx.emit(ConversationEvent::MessagesEdited);
1247 cx.notify();
1248 }
1249 }
1250 }
1251
1252 fn insert_message_after(
1253 &mut self,
1254 message_id: MessageId,
1255 role: Role,
1256 status: MessageStatus,
1257 cx: &mut ModelContext<Self>,
1258 ) -> Option<MessageAnchor> {
1259 if let Some(prev_message_ix) = self
1260 .message_anchors
1261 .iter()
1262 .position(|message| message.id == message_id)
1263 {
1264 // Find the next valid message after the one we were given.
1265 let mut next_message_ix = prev_message_ix + 1;
1266 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1267 if next_message.start.is_valid(self.buffer.read(cx)) {
1268 break;
1269 }
1270 next_message_ix += 1;
1271 }
1272
1273 let start = self.buffer.update(cx, |buffer, cx| {
1274 let offset = self
1275 .message_anchors
1276 .get(next_message_ix)
1277 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1278 buffer.edit([(offset..offset, "\n")], None, cx);
1279 buffer.anchor_before(offset + 1)
1280 });
1281 let message = MessageAnchor {
1282 id: MessageId(post_inc(&mut self.next_message_id.0)),
1283 start,
1284 };
1285 self.message_anchors
1286 .insert(next_message_ix, message.clone());
1287 self.messages_metadata.insert(
1288 message.id,
1289 MessageMetadata {
1290 role,
1291 sent_at: Local::now(),
1292 status,
1293 },
1294 );
1295 cx.emit(ConversationEvent::MessagesEdited);
1296 Some(message)
1297 } else {
1298 None
1299 }
1300 }
1301
1302 fn split_message(
1303 &mut self,
1304 range: Range<usize>,
1305 cx: &mut ModelContext<Self>,
1306 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1307 let start_message = self.message_for_offset(range.start, cx);
1308 let end_message = self.message_for_offset(range.end, cx);
1309 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1310 // Prevent splitting when range spans multiple messages.
1311 if start_message.id != end_message.id {
1312 return (None, None);
1313 }
1314
1315 let message = start_message;
1316 let role = message.role;
1317 let mut edited_buffer = false;
1318
1319 let mut suffix_start = None;
1320 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1321 {
1322 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1323 suffix_start = Some(range.end + 1);
1324 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1325 suffix_start = Some(range.end);
1326 }
1327 }
1328
1329 let suffix = if let Some(suffix_start) = suffix_start {
1330 MessageAnchor {
1331 id: MessageId(post_inc(&mut self.next_message_id.0)),
1332 start: self.buffer.read(cx).anchor_before(suffix_start),
1333 }
1334 } else {
1335 self.buffer.update(cx, |buffer, cx| {
1336 buffer.edit([(range.end..range.end, "\n")], None, cx);
1337 });
1338 edited_buffer = true;
1339 MessageAnchor {
1340 id: MessageId(post_inc(&mut self.next_message_id.0)),
1341 start: self.buffer.read(cx).anchor_before(range.end + 1),
1342 }
1343 };
1344
1345 self.message_anchors
1346 .insert(message.index_range.end + 1, suffix.clone());
1347 self.messages_metadata.insert(
1348 suffix.id,
1349 MessageMetadata {
1350 role,
1351 sent_at: Local::now(),
1352 status: MessageStatus::Done,
1353 },
1354 );
1355
1356 let new_messages =
1357 if range.start == range.end || range.start == message.offset_range.start {
1358 (None, Some(suffix))
1359 } else {
1360 let mut prefix_end = None;
1361 if range.start > message.offset_range.start
1362 && range.end < message.offset_range.end - 1
1363 {
1364 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1365 prefix_end = Some(range.start + 1);
1366 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1367 == Some('\n')
1368 {
1369 prefix_end = Some(range.start);
1370 }
1371 }
1372
1373 let selection = if let Some(prefix_end) = prefix_end {
1374 cx.emit(ConversationEvent::MessagesEdited);
1375 MessageAnchor {
1376 id: MessageId(post_inc(&mut self.next_message_id.0)),
1377 start: self.buffer.read(cx).anchor_before(prefix_end),
1378 }
1379 } else {
1380 self.buffer.update(cx, |buffer, cx| {
1381 buffer.edit([(range.start..range.start, "\n")], None, cx)
1382 });
1383 edited_buffer = true;
1384 MessageAnchor {
1385 id: MessageId(post_inc(&mut self.next_message_id.0)),
1386 start: self.buffer.read(cx).anchor_before(range.end + 1),
1387 }
1388 };
1389
1390 self.message_anchors
1391 .insert(message.index_range.end + 1, selection.clone());
1392 self.messages_metadata.insert(
1393 selection.id,
1394 MessageMetadata {
1395 role,
1396 sent_at: Local::now(),
1397 status: MessageStatus::Done,
1398 },
1399 );
1400 (Some(selection), Some(suffix))
1401 };
1402
1403 if !edited_buffer {
1404 cx.emit(ConversationEvent::MessagesEdited);
1405 }
1406 new_messages
1407 } else {
1408 (None, None)
1409 }
1410 }
1411
1412 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1413 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1414 let api_key = self.api_key.borrow().clone();
1415 if let Some(api_key) = api_key {
1416 let messages = self
1417 .messages(cx)
1418 .take(2)
1419 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1420 .chain(Some(RequestMessage {
1421 role: Role::User,
1422 content:
1423 "Summarize the conversation into a short title without punctuation"
1424 .into(),
1425 }));
1426 let request = OpenAIRequest {
1427 model: self.model.full_name().to_string(),
1428 messages: messages.collect(),
1429 stream: true,
1430 };
1431
1432 let stream = stream_completion(api_key, cx.background().clone(), request);
1433 self.pending_summary = cx.spawn(|this, mut cx| {
1434 async move {
1435 let mut messages = stream.await?;
1436
1437 while let Some(message) = messages.next().await {
1438 let mut message = message?;
1439 if let Some(choice) = message.choices.pop() {
1440 let text = choice.delta.content.unwrap_or_default();
1441 this.update(&mut cx, |this, cx| {
1442 this.summary
1443 .get_or_insert(Default::default())
1444 .text
1445 .push_str(&text);
1446 cx.emit(ConversationEvent::SummaryChanged);
1447 });
1448 }
1449 }
1450
1451 this.update(&mut cx, |this, cx| {
1452 if let Some(summary) = this.summary.as_mut() {
1453 summary.done = true;
1454 cx.emit(ConversationEvent::SummaryChanged);
1455 }
1456 });
1457
1458 anyhow::Ok(())
1459 }
1460 .log_err()
1461 });
1462 }
1463 }
1464 }
1465
1466 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1467 self.messages_for_offsets([offset], cx).pop()
1468 }
1469
    /// Returns the messages containing the given offsets, in order.
    ///
    /// Consecutive offsets that fall within the same message produce a single entry.
    /// Offsets not contained in any message (e.g. the end of the buffer) resolve to
    /// the last message. Returns an empty `Vec` when there are no messages.
    ///
    /// NOTE(review): assumes `offsets` is sorted ascending (callers pass selection
    /// heads). An offset lying before the current message is never matched and
    /// resolves to the last message — confirm callers always pass sorted input.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Stops on the last message even if it doesn't contain the offset.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            // No messages at all: nothing more can be resolved.
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // Once on the last message, every remaining offset maps to it, so skip them all.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1500
    /// Iterates over the conversation's messages in anchor order, resolving each one
    /// against the current buffer contents.
    ///
    /// Anchors whose start is no longer valid in the buffer are folded into the
    /// preceding message:
    /// - a message's `offset_range` runs from its own anchor to the next *valid*
    ///   anchor, or to the end of the buffer;
    /// - its `index_range` runs from its own index to the index of the last anchor
    ///   folded into it. Note that `index_range.end` is that last index itself, not
    ///   one past it.
    ///
    /// Only the very first anchor is yielded without a validity check.
    ///
    /// NOTE(review): iteration ends entirely (rather than skipping the entry) at the
    /// first anchor that has no `messages_metadata` entry.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            // The body always returns on its first pass; `while let` only pulls the next anchor.
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume invalid anchors, attributing their text to this message.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                // With no valid anchor after it, the message extends to the end of the buffer.
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1535
1536 fn save(
1537 &mut self,
1538 debounce: Option<Duration>,
1539 fs: Arc<dyn Fs>,
1540 cx: &mut ModelContext<Conversation>,
1541 ) {
1542 self.pending_save = cx.spawn(|this, mut cx| async move {
1543 if let Some(debounce) = debounce {
1544 cx.background().timer(debounce).await;
1545 }
1546
1547 let (old_path, summary) = this.read_with(&cx, |this, _| {
1548 let path = this.path.clone();
1549 let summary = if let Some(summary) = this.summary.as_ref() {
1550 if summary.done {
1551 Some(summary.text.clone())
1552 } else {
1553 None
1554 }
1555 } else {
1556 None
1557 };
1558 (path, summary)
1559 });
1560
1561 if let Some(summary) = summary {
1562 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1563 let path = if let Some(old_path) = old_path {
1564 old_path
1565 } else {
1566 let mut discriminant = 1;
1567 let mut new_path;
1568 loop {
1569 new_path = CONVERSATIONS_DIR.join(&format!(
1570 "{} - {}.zed.json",
1571 summary.trim(),
1572 discriminant
1573 ));
1574 if fs.is_file(&new_path).await {
1575 discriminant += 1;
1576 } else {
1577 break;
1578 }
1579 }
1580 new_path
1581 };
1582
1583 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1584 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1585 .await?;
1586 this.update(&mut cx, |this, _| this.path = Some(path));
1587 }
1588
1589 Ok(())
1590 });
1591 }
1592}
1593
/// A batch of assistant completion tasks started by a single `Conversation::assist` call.
///
/// Holding the tasks keeps them alive. Dropping this value (e.g. when
/// `cancel_last_assist` pops it) drops the tasks with it.
struct PendingCompletion {
    /// Value of the conversation's `completion_count` when the batch was created.
    id: usize,
    _tasks: Vec<Task<()>>,
}
1598
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changes, which also changes `title()`.
    TabContentChanged,
}
1602
/// Where the newest cursor sits within the viewport.
///
/// Captured so the view can stay pinned to the cursor while a completion streams
/// into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x` is the horizontal scroll position; `y` is the number of rows between the
    /// top of the viewport and the cursor.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1608
/// A view that edits a `Conversation` through a regular `Editor` over its buffer,
/// with a header block rendered above each message.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to keep pinned while completions stream in, if any.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1617
1618impl ConversationEditor {
1619 fn new(
1620 api_key: Rc<RefCell<Option<String>>>,
1621 language_registry: Arc<LanguageRegistry>,
1622 fs: Arc<dyn Fs>,
1623 cx: &mut ViewContext<Self>,
1624 ) -> Self {
1625 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1626 Self::for_conversation(conversation, fs, cx)
1627 }
1628
1629 fn for_conversation(
1630 conversation: ModelHandle<Conversation>,
1631 fs: Arc<dyn Fs>,
1632 cx: &mut ViewContext<Self>,
1633 ) -> Self {
1634 let editor = cx.add_view(|cx| {
1635 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1636 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1637 editor.set_show_gutter(false, cx);
1638 editor.set_show_wrap_guides(false, cx);
1639 editor
1640 });
1641
1642 let _subscriptions = vec![
1643 cx.observe(&conversation, |_, _, cx| cx.notify()),
1644 cx.subscribe(&conversation, Self::handle_conversation_event),
1645 cx.subscribe(&editor, Self::handle_editor_event),
1646 ];
1647
1648 let mut this = Self {
1649 conversation,
1650 editor,
1651 blocks: Default::default(),
1652 scroll_position: None,
1653 fs,
1654 _subscriptions,
1655 };
1656 this.update_message_headers(cx);
1657 this
1658 }
1659
1660 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1661 let cursors = self.cursors(cx);
1662
1663 let user_messages = self.conversation.update(cx, |conversation, cx| {
1664 let selected_messages = conversation
1665 .messages_for_offsets(cursors, cx)
1666 .into_iter()
1667 .map(|message| message.id)
1668 .collect();
1669 conversation.assist(selected_messages, cx)
1670 });
1671 let new_selections = user_messages
1672 .iter()
1673 .map(|message| {
1674 let cursor = message
1675 .start
1676 .to_offset(self.conversation.read(cx).buffer.read(cx));
1677 cursor..cursor
1678 })
1679 .collect::<Vec<_>>();
1680 if !new_selections.is_empty() {
1681 self.editor.update(cx, |editor, cx| {
1682 editor.change_selections(
1683 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1684 cx,
1685 |selections| selections.select_ranges(new_selections),
1686 );
1687 });
1688 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1689 cx.defer(|this, _| this.scroll_position = None);
1690 }
1691 }
1692
1693 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1694 if !self
1695 .conversation
1696 .update(cx, |conversation, _| conversation.cancel_last_assist())
1697 {
1698 cx.propagate_action();
1699 }
1700 }
1701
1702 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1703 let cursors = self.cursors(cx);
1704 self.conversation.update(cx, |conversation, cx| {
1705 let messages = conversation
1706 .messages_for_offsets(cursors, cx)
1707 .into_iter()
1708 .map(|message| message.id)
1709 .collect();
1710 conversation.cycle_message_roles(messages, cx)
1711 });
1712 }
1713
1714 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1715 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1716 selections
1717 .into_iter()
1718 .map(|selection| selection.head())
1719 .collect()
1720 }
1721
1722 fn handle_conversation_event(
1723 &mut self,
1724 _: ModelHandle<Conversation>,
1725 event: &ConversationEvent,
1726 cx: &mut ViewContext<Self>,
1727 ) {
1728 match event {
1729 ConversationEvent::MessagesEdited => {
1730 self.update_message_headers(cx);
1731 self.conversation.update(cx, |conversation, cx| {
1732 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1733 });
1734 }
1735 ConversationEvent::SummaryChanged => {
1736 cx.emit(ConversationEditorEvent::TabContentChanged);
1737 self.conversation.update(cx, |conversation, cx| {
1738 conversation.save(None, self.fs.clone(), cx);
1739 });
1740 }
1741 ConversationEvent::StreamedCompletion => {
1742 self.editor.update(cx, |editor, cx| {
1743 if let Some(scroll_position) = self.scroll_position {
1744 let snapshot = editor.snapshot(cx);
1745 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1746 let scroll_top =
1747 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1748 editor.set_scroll_position(
1749 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1750 cx,
1751 );
1752 }
1753 });
1754 }
1755 }
1756 }
1757
1758 fn handle_editor_event(
1759 &mut self,
1760 _: ViewHandle<Editor>,
1761 event: &editor::Event,
1762 cx: &mut ViewContext<Self>,
1763 ) {
1764 match event {
1765 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1766 let cursor_scroll_position = self.cursor_scroll_position(cx);
1767 if *autoscroll {
1768 self.scroll_position = cursor_scroll_position;
1769 } else if self.scroll_position != cursor_scroll_position {
1770 self.scroll_position = None;
1771 }
1772 }
1773 editor::Event::SelectionsChanged { .. } => {
1774 self.scroll_position = self.cursor_scroll_position(cx);
1775 }
1776 _ => {}
1777 }
1778 }
1779
1780 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1781 self.editor.update(cx, |editor, cx| {
1782 let snapshot = editor.snapshot(cx);
1783 let cursor = editor.selections.newest_anchor().head();
1784 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1785 let scroll_position = editor
1786 .scroll_manager
1787 .anchor()
1788 .scroll_position(&snapshot.display_snapshot);
1789
1790 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1791 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1792 Some(ScrollPosition {
1793 cursor,
1794 offset_before_cursor: vec2f(
1795 scroll_position.x(),
1796 cursor_row - scroll_position.y(),
1797 ),
1798 })
1799 } else {
1800 None
1801 }
1802 })
1803 }
1804
1805 fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1806 self.editor.update(cx, |editor, cx| {
1807 let buffer = editor.buffer().read(cx).snapshot(cx);
1808 let excerpt_id = *buffer.as_singleton().unwrap().0;
1809 let old_blocks = std::mem::take(&mut self.blocks);
1810 let new_blocks = self
1811 .conversation
1812 .read(cx)
1813 .messages(cx)
1814 .map(|message| BlockProperties {
1815 position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1816 height: 2,
1817 style: BlockStyle::Sticky,
1818 render: Arc::new({
1819 let conversation = self.conversation.clone();
1820 // let metadata = message.metadata.clone();
1821 // let message = message.clone();
1822 move |cx| {
1823 enum Sender {}
1824 enum ErrorTooltip {}
1825
1826 let theme = theme::current(cx);
1827 let style = &theme.assistant;
1828 let message_id = message.id;
1829 let sender = MouseEventHandler::new::<Sender, _>(
1830 message_id.0,
1831 cx,
1832 |state, _| match message.role {
1833 Role::User => {
1834 let style = style.user_sender.style_for(state);
1835 Label::new("You", style.text.clone())
1836 .contained()
1837 .with_style(style.container)
1838 }
1839 Role::Assistant => {
1840 let style = style.assistant_sender.style_for(state);
1841 Label::new("Assistant", style.text.clone())
1842 .contained()
1843 .with_style(style.container)
1844 }
1845 Role::System => {
1846 let style = style.system_sender.style_for(state);
1847 Label::new("System", style.text.clone())
1848 .contained()
1849 .with_style(style.container)
1850 }
1851 },
1852 )
1853 .with_cursor_style(CursorStyle::PointingHand)
1854 .on_down(MouseButton::Left, {
1855 let conversation = conversation.clone();
1856 move |_, _, cx| {
1857 conversation.update(cx, |conversation, cx| {
1858 conversation.cycle_message_roles(
1859 HashSet::from_iter(Some(message_id)),
1860 cx,
1861 )
1862 })
1863 }
1864 });
1865
1866 Flex::row()
1867 .with_child(sender.aligned())
1868 .with_child(
1869 Label::new(
1870 message.sent_at.format("%I:%M%P").to_string(),
1871 style.sent_at.text.clone(),
1872 )
1873 .contained()
1874 .with_style(style.sent_at.container)
1875 .aligned(),
1876 )
1877 .with_children(
1878 if let MessageStatus::Error(error) = &message.status {
1879 Some(
1880 Svg::new("icons/circle_x_mark_12.svg")
1881 .with_color(style.error_icon.color)
1882 .constrained()
1883 .with_width(style.error_icon.width)
1884 .contained()
1885 .with_style(style.error_icon.container)
1886 .with_tooltip::<ErrorTooltip>(
1887 message_id.0,
1888 error.to_string(),
1889 None,
1890 theme.tooltip.clone(),
1891 cx,
1892 )
1893 .aligned(),
1894 )
1895 } else {
1896 None
1897 },
1898 )
1899 .aligned()
1900 .left()
1901 .contained()
1902 .with_style(style.message_header)
1903 .into_any()
1904 }
1905 }),
1906 disposition: BlockDisposition::Above,
1907 })
1908 .collect::<Vec<_>>();
1909
1910 editor.remove_blocks(old_blocks, None, cx);
1911 let ids = editor.insert_blocks(new_blocks, None, cx);
1912 self.blocks = HashSet::from_iter(ids);
1913 });
1914 }
1915
1916 fn quote_selection(
1917 workspace: &mut Workspace,
1918 _: &QuoteSelection,
1919 cx: &mut ViewContext<Workspace>,
1920 ) {
1921 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1922 return;
1923 };
1924 let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
1925 return;
1926 };
1927
1928 let text = editor.read_with(cx, |editor, cx| {
1929 let range = editor.selections.newest::<usize>(cx).range();
1930 let buffer = editor.buffer().read(cx).snapshot(cx);
1931 let start_language = buffer.language_at(range.start);
1932 let end_language = buffer.language_at(range.end);
1933 let language_name = if start_language == end_language {
1934 start_language.map(|language| language.name())
1935 } else {
1936 None
1937 };
1938 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1939
1940 let selected_text = buffer.text_for_range(range).collect::<String>();
1941 if selected_text.is_empty() {
1942 None
1943 } else {
1944 Some(if language_name == "markdown" {
1945 selected_text
1946 .lines()
1947 .map(|line| format!("> {}", line))
1948 .collect::<Vec<_>>()
1949 .join("\n")
1950 } else {
1951 format!("```{language_name}\n{selected_text}\n```")
1952 })
1953 }
1954 });
1955
1956 // Activate the panel
1957 if !panel.read(cx).has_focus(cx) {
1958 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1959 }
1960
1961 if let Some(text) = text {
1962 panel.update(cx, |panel, cx| {
1963 let conversation = panel
1964 .active_editor()
1965 .cloned()
1966 .unwrap_or_else(|| panel.new_conversation(cx));
1967 conversation.update(cx, |conversation, cx| {
1968 conversation
1969 .editor
1970 .update(cx, |editor, cx| editor.insert(&text, cx))
1971 });
1972 });
1973 }
1974 }
1975
1976 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1977 let editor = self.editor.read(cx);
1978 let conversation = self.conversation.read(cx);
1979 if editor.selections.count() == 1 {
1980 let selection = editor.selections.newest::<usize>(cx);
1981 let mut copied_text = String::new();
1982 let mut spanned_messages = 0;
1983 for message in conversation.messages(cx) {
1984 if message.offset_range.start >= selection.range().end {
1985 break;
1986 } else if message.offset_range.end >= selection.range().start {
1987 let range = cmp::max(message.offset_range.start, selection.range().start)
1988 ..cmp::min(message.offset_range.end, selection.range().end);
1989 if !range.is_empty() {
1990 spanned_messages += 1;
1991 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1992 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1993 copied_text.push_str(&chunk);
1994 }
1995 copied_text.push('\n');
1996 }
1997 }
1998 }
1999
2000 if spanned_messages > 1 {
2001 cx.platform()
2002 .write_to_clipboard(ClipboardItem::new(copied_text));
2003 return;
2004 }
2005 }
2006
2007 cx.propagate_action();
2008 }
2009
2010 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2011 self.conversation.update(cx, |conversation, cx| {
2012 let selections = self.editor.read(cx).selections.disjoint_anchors();
2013 for selection in selections.into_iter() {
2014 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2015 let range = selection
2016 .map(|endpoint| endpoint.to_offset(&buffer))
2017 .range();
2018 conversation.split_message(range, cx);
2019 }
2020 });
2021 }
2022
2023 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2024 self.conversation.update(cx, |conversation, cx| {
2025 conversation.save(None, self.fs.clone(), cx)
2026 });
2027 }
2028
2029 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2030 self.conversation.update(cx, |conversation, cx| {
2031 let new_model = conversation.model.cycle();
2032 conversation.set_model(new_model, cx);
2033 });
2034 }
2035
2036 fn title(&self, cx: &AppContext) -> String {
2037 self.conversation
2038 .read(cx)
2039 .summary
2040 .as_ref()
2041 .map(|summary| summary.text.clone())
2042 .unwrap_or_else(|| "New Conversation".into())
2043 }
2044
2045 fn render_current_model(
2046 &self,
2047 style: &AssistantStyle,
2048 cx: &mut ViewContext<Self>,
2049 ) -> impl Element<Self> {
2050 enum Model {}
2051
2052 MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
2053 let style = style.model.style_for(state);
2054 let model_display_name = self.conversation.read(cx).model.short_name();
2055 Label::new(model_display_name, style.text.clone())
2056 .contained()
2057 .with_style(style.container)
2058 })
2059 .with_cursor_style(CursorStyle::PointingHand)
2060 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2061 }
2062
2063 fn render_remaining_tokens(
2064 &self,
2065 style: &AssistantStyle,
2066 cx: &mut ViewContext<Self>,
2067 ) -> Option<impl Element<Self>> {
2068 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2069 let remaining_tokens_style = if remaining_tokens <= 0 {
2070 &style.no_remaining_tokens
2071 } else if remaining_tokens <= 500 {
2072 &style.low_remaining_tokens
2073 } else {
2074 &style.remaining_tokens
2075 };
2076 Some(
2077 Label::new(
2078 remaining_tokens.to_string(),
2079 remaining_tokens_style.text.clone(),
2080 )
2081 .contained()
2082 .with_style(remaining_tokens_style.container),
2083 )
2084 }
2085}
2086
/// Registers `ConversationEditorEvent` as the event type this view emits.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2090
2091impl View for ConversationEditor {
2092 fn ui_name() -> &'static str {
2093 "ConversationEditor"
2094 }
2095
2096 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2097 let theme = &theme::current(cx).assistant;
2098 Stack::new()
2099 .with_child(
2100 ChildView::new(&self.editor, cx)
2101 .contained()
2102 .with_style(theme.container),
2103 )
2104 .with_child(
2105 Flex::row()
2106 .with_child(self.render_current_model(theme, cx))
2107 .with_children(self.render_remaining_tokens(theme, cx))
2108 .aligned()
2109 .top()
2110 .right(),
2111 )
2112 .into_any()
2113 }
2114
2115 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2116 if cx.is_self_focused() {
2117 cx.focus(&self.editor);
2118 }
2119 }
2120}
2121
/// A message's identity plus the buffer position where its text begins.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    /// Anchor at the start of the message's text. It may become invalid once that
    /// text is deleted.
    start: language::Anchor,
}
2127
/// A message resolved against the current buffer: its anchor and metadata,
/// plus computed offsets (produced by `Conversation::messages`).
#[derive(Clone, Debug)]
pub struct Message {
    /// Offsets of the message's text, up to the next valid message (or buffer end).
    offset_range: Range<usize>,
    /// Indices into `message_anchors` covered by this message. `start` is the
    /// message's own anchor; `end` is the last invalid anchor folded into it.
    index_range: Range<usize>,
    id: MessageId,
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2138
2139impl Message {
2140 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2141 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2142 content.extend(buffer.text_for_range(self.offset_range.clone()));
2143 RequestMessage {
2144 role: self.role,
2145 content: content.trim_end().into(),
2146 }
2147 }
2148}
2149
2150async fn stream_completion(
2151 api_key: String,
2152 executor: Arc<Background>,
2153 mut request: OpenAIRequest,
2154) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2155 request.stream = true;
2156
2157 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2158
2159 let json_data = serde_json::to_string(&request)?;
2160 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2161 .header("Content-Type", "application/json")
2162 .header("Authorization", format!("Bearer {}", api_key))
2163 .body(json_data)?
2164 .send_async()
2165 .await?;
2166
2167 let status = response.status();
2168 if status == StatusCode::OK {
2169 executor
2170 .spawn(async move {
2171 let mut lines = BufReader::new(response.body_mut()).lines();
2172
2173 fn parse_line(
2174 line: Result<String, io::Error>,
2175 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2176 if let Some(data) = line?.strip_prefix("data: ") {
2177 let event = serde_json::from_str(&data)?;
2178 Ok(Some(event))
2179 } else {
2180 Ok(None)
2181 }
2182 }
2183
2184 while let Some(line) = lines.next().await {
2185 if let Some(event) = parse_line(line).transpose() {
2186 let done = event.as_ref().map_or(false, |event| {
2187 event
2188 .choices
2189 .last()
2190 .map_or(false, |choice| choice.finish_reason.is_some())
2191 });
2192 if tx.unbounded_send(event).is_err() {
2193 break;
2194 }
2195
2196 if done {
2197 break;
2198 }
2199 }
2200 }
2201
2202 anyhow::Ok(())
2203 })
2204 .detach();
2205
2206 Ok(rx)
2207 } else {
2208 let mut body = String::new();
2209 response.body_mut().read_to_string(&mut body).await?;
2210
2211 #[derive(Deserialize)]
2212 struct OpenAIResponse {
2213 error: OpenAIError,
2214 }
2215
2216 #[derive(Deserialize)]
2217 struct OpenAIError {
2218 message: String,
2219 }
2220
2221 match serde_json::from_str::<OpenAIResponse>(&body) {
2222 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2223 "Failed to connect to OpenAI API: {}",
2224 response.error.message,
2225 )),
2226
2227 _ => Err(anyhow!(
2228 "Failed to connect to OpenAI API: {} {}",
2229 response.status(),
2230 body,
2231 )),
2232 }
2233 }
2234}
2235
// Unit tests for the `Conversation` model:
// - how message offset ranges track buffer edits;
// - how messages are inserted, merged and split;
// - how buffer offsets map back to messages;
// - how a conversation survives a serialize/deserialize round trip.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    // Message ranges must follow insertions and buffer edits. A deletion that
    // spans message boundaries merges messages, and undo/redo must revert and
    // reapply that merge.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting a message adds one separator character (a newline) that
        // is counted in the preceding message's range.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Typing into both messages at once grows and shifts their ranges.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        // Inserting after the last message appends a new message at the end.
        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message directly
        // after message_2, i.e. before the existing message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // `split_message` reuses an existing newline as the boundary when it can
    // and inserts one when it cannot. Every new message inherits the user
    // role in these cases.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Offset 9 is the start of "ccc" inside message_2. The newline just
        // before it is reused, so the text is unchanged.
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Splitting at the same offset again (now the start of message_4)
        // inserts a newline. message_4 is left as an empty line, and
        // message_5 takes the rest.
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range (the "dd" of "ddd") gives the selection
        // its own message (message_6). The remainder goes into message_7, and
        // a newline is inserted after the selection.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // `messages_for_offsets` maps buffer offsets to the messages containing
    // them. Offsets that land in the same message yield it once, and the
    // end-of-buffer offset maps to the last message.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build "aaa\nbbb\nccc" as three user messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which is reported once;
        // offset 11 (end of buffer) belongs to message_3.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message still owns the end-of-buffer offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages `messages_for_offsets` yields.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    // A serialized conversation must deserialize to the same buffer text and
    // the same message ids, roles and ranges.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Closing the transaction presumably keeps this edit separate from the
        // next insertion. The assertions below confirm that the undo only
        // removes `_message_3` and its newline.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        // Round-trip through `serialize` / `deserialize`.
        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    // Snapshot of every message as `(id, role, offset_range)`, in the order
    // `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}