1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API. It doubles as the key under which the
/// user's API key is written to, read from and deleted from the platform
/// credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace used by the assistant panel and its
// conversation editors (toolbar buttons reference several of them as
// tooltip key bindings).
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
/// Registers the assistant settings and wires every assistant action to its
/// handler on `AssistantPanel`, `ConversationEditor` or `Workspace`.
pub fn init(cx: &mut AppContext) {
    settings::register::<AssistantSettings>(cx);
    // Inside the panel, `workspace::NewFile` starts a new conversation.
    cx.add_action(
        |this: &mut AssistantPanel,
         _: &workspace::NewFile,
         cx: &mut ViewContext<AssistantPanel>| {
            this.new_conversation(cx);
        },
    );
    // Conversation-editor handlers. The `capture_action` registrations
    // presumably intercept the action before the embedded `Editor`'s own
    // handler for it — NOTE(review): confirm against gpui dispatch rules.
    cx.add_action(ConversationEditor::assist);
    cx.capture_action(ConversationEditor::cancel_last_assist);
    cx.capture_action(ConversationEditor::save);
    cx.add_action(ConversationEditor::quote_selection);
    cx.capture_action(ConversationEditor::copy);
    cx.add_action(ConversationEditor::split);
    cx.capture_action(ConversationEditor::cycle_message_role);
    // Panel-level handlers: API key management, zoom and buffer search.
    cx.add_action(AssistantPanel::save_api_key);
    cx.add_action(AssistantPanel::reset_api_key);
    cx.add_action(AssistantPanel::toggle_zoom);
    cx.add_action(AssistantPanel::deploy);
    cx.add_action(AssistantPanel::select_next_match);
    cx.add_action(AssistantPanel::select_prev_match);
    cx.add_action(AssistantPanel::handle_editor_cancel);
    // Workspace-level action that toggles focus into/out of the panel.
    cx.add_action(
        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        },
    );
}
93
/// Events emitted by [`AssistantPanel`]. The `Panel` implementation maps them
/// onto zoom in/out, focus, close and dock-position changes.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    ZoomIn,
    ZoomOut,
    Focus,
    Close,
    /// Emitted when a settings change moves the panel to another dock.
    DockPositionChanged,
}
102
/// Dockable panel hosting assistant conversations and the list of
/// conversations saved on disk.
pub struct AssistantPanel {
    /// Workspace that owns this panel.
    workspace: WeakViewHandle<Workspace>,
    /// User-chosen width when docked left/right; `None` uses the settings default.
    width: Option<f32>,
    /// User-chosen height when docked at the bottom; `None` uses the settings default.
    height: Option<f32>,
    /// Index into `editors` of the conversation on screen; `None` shows the
    /// saved-conversation list instead.
    active_editor_index: Option<usize>,
    /// Index that was active before the last change; restored by the history
    /// (hamburger) button.
    prev_active_editor_index: Option<usize>,
    /// All conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations on disk, refreshed by a directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// Scroll state of the saved-conversation list.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    /// Toolbar hosting the buffer search bar.
    toolbar: ViewHandle<Toolbar>,
    /// API key shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Present while prompting for an API key; `render` then shows only the prompt.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Ensures the env var / credential store lookup happens at most once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    /// Directory-watch task, held for the panel's lifetime (presumably
    /// dropping the `Task` would cancel the watcher).
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
    /// Asynchronously builds the assistant panel for `workspace`.
    ///
    /// Lists the saved conversations on disk (errors are logged and treated as
    /// an empty list), then creates the view. The view starts a watcher on
    /// `CONVERSATIONS_DIR` that re-lists saved conversations on every change,
    /// and emits `AssistantPanelEvent::DockPositionChanged` whenever a
    /// settings change alters the dock position.
    pub fn load(
        workspace: WeakViewHandle<Workspace>,
        cx: AsyncAppContext,
    ) -> Task<Result<ViewHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                .await
                .log_err()
                .unwrap_or_default();

            // TODO: deserialize state.
            let workspace_handle = workspace.clone();
            workspace.update(&mut cx, |workspace, cx| {
                cx.add_view::<Self, _>(|cx| {
                    // Passed to `fs.watch`; presumably the event-batching latency.
                    const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
                    let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
                        let mut events = fs
                            .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
                            .await;
                        while events.next().await.is_some() {
                            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                                .await
                                .log_err()
                                .unwrap_or_default();
                            // The panel may already be gone; ignore that case.
                            this.update(&mut cx, |this, cx| {
                                this.saved_conversations = saved_conversations;
                                cx.notify();
                            })
                            .ok();
                        }

                        anyhow::Ok(())
                    });

                    let toolbar = cx.add_view(|cx| {
                        let mut toolbar = Toolbar::new();
                        toolbar.set_can_navigate(false, cx);
                        toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
                        toolbar
                    });
                    let mut this = Self {
                        workspace: workspace_handle,
                        active_editor_index: Default::default(),
                        prev_active_editor_index: Default::default(),
                        editors: Default::default(),
                        saved_conversations,
                        saved_conversations_list_state: Default::default(),
                        zoomed: false,
                        has_focus: false,
                        toolbar,
                        api_key: Rc::new(RefCell::new(None)),
                        api_key_editor: None,
                        has_read_credentials: false,
                        languages: workspace.app_state().languages.clone(),
                        fs: workspace.app_state().fs.clone(),
                        width: None,
                        height: None,
                        subscriptions: Default::default(),
                        _watch_saved_conversations,
                    };

                    // Track the dock position so a settings change that moves
                    // the panel is reported as `DockPositionChanged`.
                    let mut old_dock_position = this.position(cx);
                    this.subscriptions =
                        vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
                            let new_dock_position = this.position(cx);
                            if new_dock_position != old_dock_position {
                                old_dock_position = new_dock_position;
                                cx.emit(AssistantPanelEvent::DockPositionChanged);
                            }
                            cx.notify();
                        })];

                    this
                })
            })
        })
    }
203
204 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
205 let editor = cx.add_view(|cx| {
206 ConversationEditor::new(
207 self.api_key.clone(),
208 self.languages.clone(),
209 self.fs.clone(),
210 cx,
211 )
212 });
213 self.add_conversation(editor.clone(), cx);
214 editor
215 }
216
217 fn add_conversation(
218 &mut self,
219 editor: ViewHandle<ConversationEditor>,
220 cx: &mut ViewContext<Self>,
221 ) {
222 self.subscriptions
223 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
224
225 let conversation = editor.read(cx).conversation.clone();
226 self.subscriptions
227 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
228
229 let index = self.editors.len();
230 self.editors.push(editor);
231 self.set_active_editor_index(Some(index), cx);
232 }
233
234 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
235 self.prev_active_editor_index = self.active_editor_index;
236 self.active_editor_index = index;
237 if let Some(editor) = self.active_editor() {
238 let editor = editor.read(cx).editor.clone();
239 self.toolbar.update(cx, |toolbar, cx| {
240 toolbar.set_active_item(Some(&editor), cx);
241 });
242 if self.has_focus(cx) {
243 cx.focus(&editor);
244 }
245 } else {
246 self.toolbar.update(cx, |toolbar, cx| {
247 toolbar.set_active_item(None, cx);
248 });
249 }
250
251 cx.notify();
252 }
253
    /// Re-renders the panel when a conversation editor reports
    /// `TabContentChanged`; the panel header displays the active editor's title.
    fn handle_conversation_editor_event(
        &mut self,
        _: ViewHandle<ConversationEditor>,
        event: &ConversationEditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEditorEvent::TabContentChanged => cx.notify(),
        }
    }
264
265 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
266 if let Some(api_key) = self
267 .api_key_editor
268 .as_ref()
269 .map(|editor| editor.read(cx).text(cx))
270 {
271 if !api_key.is_empty() {
272 cx.platform()
273 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
274 .log_err();
275 *self.api_key.borrow_mut() = Some(api_key);
276 self.api_key_editor.take();
277 cx.focus_self();
278 cx.notify();
279 }
280 } else {
281 cx.propagate_action();
282 }
283 }
284
285 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
286 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
287 self.api_key.take();
288 self.api_key_editor = Some(build_api_key_editor(cx));
289 cx.focus_self();
290 cx.notify();
291 }
292
293 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
294 if self.zoomed {
295 cx.emit(AssistantPanelEvent::ZoomOut)
296 } else {
297 cx.emit(AssistantPanelEvent::ZoomIn)
298 }
299 }
300
301 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
302 let mut propagate_action = true;
303 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
304 search_bar.update(cx, |search_bar, cx| {
305 if search_bar.show(cx) {
306 search_bar.search_suggested(cx);
307 if action.focus {
308 search_bar.select_query(cx);
309 cx.focus_self();
310 }
311 propagate_action = false
312 }
313 });
314 }
315 if propagate_action {
316 cx.propagate_action();
317 }
318 }
319
320 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
321 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
322 if !search_bar.read(cx).is_dismissed() {
323 search_bar.update(cx, |search_bar, cx| {
324 search_bar.dismiss(&Default::default(), cx)
325 });
326 return;
327 }
328 }
329 cx.propagate_action();
330 }
331
332 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
335 }
336 }
337
338 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
339 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
340 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
341 }
342 }
343
344 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
345 self.editors.get(self.active_editor_index?)
346 }
347
348 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
349 enum History {}
350 let theme = theme::current(cx);
351 let tooltip_style = theme::current(cx).tooltip.clone();
352 MouseEventHandler::new::<History, _>(0, cx, |state, _| {
353 let style = theme.assistant.hamburger_button.style_for(state);
354 Svg::for_style(style.icon.clone())
355 .contained()
356 .with_style(style.container)
357 })
358 .with_cursor_style(CursorStyle::PointingHand)
359 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
360 if this.active_editor().is_some() {
361 this.set_active_editor_index(None, cx);
362 } else {
363 this.set_active_editor_index(this.prev_active_editor_index, cx);
364 }
365 })
366 .with_tooltip::<History>(1, "History", None, tooltip_style, cx)
367 }
368
369 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
370 if self.active_editor().is_some() {
371 vec![
372 Self::render_split_button(cx).into_any(),
373 Self::render_quote_button(cx).into_any(),
374 Self::render_assist_button(cx).into_any(),
375 ]
376 } else {
377 Default::default()
378 }
379 }
380
381 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
382 let theme = theme::current(cx);
383 let tooltip_style = theme::current(cx).tooltip.clone();
384 MouseEventHandler::new::<Split, _>(0, cx, |state, _| {
385 let style = theme.assistant.split_button.style_for(state);
386 Svg::for_style(style.icon.clone())
387 .contained()
388 .with_style(style.container)
389 })
390 .with_cursor_style(CursorStyle::PointingHand)
391 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
392 if let Some(active_editor) = this.active_editor() {
393 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
394 }
395 })
396 .with_tooltip::<Split>(
397 1,
398 "Split Message",
399 Some(Box::new(Split)),
400 tooltip_style,
401 cx,
402 )
403 }
404
405 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
406 let theme = theme::current(cx);
407 let tooltip_style = theme::current(cx).tooltip.clone();
408 MouseEventHandler::new::<Assist, _>(0, cx, |state, _| {
409 let style = theme.assistant.assist_button.style_for(state);
410 Svg::for_style(style.icon.clone())
411 .contained()
412 .with_style(style.container)
413 })
414 .with_cursor_style(CursorStyle::PointingHand)
415 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
416 if let Some(active_editor) = this.active_editor() {
417 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
418 }
419 })
420 .with_tooltip::<Assist>(1, "Assist", Some(Box::new(Assist)), tooltip_style, cx)
421 }
422
423 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
424 let theme = theme::current(cx);
425 let tooltip_style = theme::current(cx).tooltip.clone();
426 MouseEventHandler::new::<QuoteSelection, _>(0, cx, |state, _| {
427 let style = theme.assistant.quote_button.style_for(state);
428 Svg::for_style(style.icon.clone())
429 .contained()
430 .with_style(style.container)
431 })
432 .with_cursor_style(CursorStyle::PointingHand)
433 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
434 if let Some(workspace) = this.workspace.upgrade(cx) {
435 cx.window_context().defer(move |cx| {
436 workspace.update(cx, |workspace, cx| {
437 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
438 });
439 });
440 }
441 })
442 .with_tooltip::<QuoteSelection>(
443 1,
444 "Quote Selection",
445 Some(Box::new(QuoteSelection)),
446 tooltip_style,
447 cx,
448 )
449 }
450
451 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
452 let theme = theme::current(cx);
453 let tooltip_style = theme::current(cx).tooltip.clone();
454 MouseEventHandler::new::<NewConversation, _>(0, cx, |state, _| {
455 let style = theme.assistant.plus_button.style_for(state);
456 Svg::for_style(style.icon.clone())
457 .contained()
458 .with_style(style.container)
459 })
460 .with_cursor_style(CursorStyle::PointingHand)
461 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
462 this.new_conversation(cx);
463 })
464 .with_tooltip::<NewConversation>(
465 1,
466 "New Conversation",
467 Some(Box::new(NewConversation)),
468 tooltip_style,
469 cx,
470 )
471 }
472
473 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
474 enum ToggleZoomButton {}
475
476 let theme = theme::current(cx);
477 let tooltip_style = theme::current(cx).tooltip.clone();
478 let style = if self.zoomed {
479 &theme.assistant.zoom_out_button
480 } else {
481 &theme.assistant.zoom_in_button
482 };
483
484 MouseEventHandler::new::<ToggleZoomButton, _>(0, cx, |state, _| {
485 let style = style.style_for(state);
486 Svg::for_style(style.icon.clone())
487 .contained()
488 .with_style(style.container)
489 })
490 .with_cursor_style(CursorStyle::PointingHand)
491 .on_click(MouseButton::Left, |_, this, cx| {
492 this.toggle_zoom(&ToggleZoom, cx);
493 })
494 .with_tooltip::<ToggleZoom>(
495 0,
496 if self.zoomed { "Zoom Out" } else { "Zoom In" },
497 Some(Box::new(ToggleZoom)),
498 tooltip_style,
499 cx,
500 )
501 }
502
503 fn render_saved_conversation(
504 &mut self,
505 index: usize,
506 cx: &mut ViewContext<Self>,
507 ) -> impl Element<Self> {
508 let conversation = &self.saved_conversations[index];
509 let path = conversation.path.clone();
510 MouseEventHandler::new::<SavedConversationMetadata, _>(index, cx, move |state, cx| {
511 let style = &theme::current(cx).assistant.saved_conversation;
512 Flex::row()
513 .with_child(
514 Label::new(
515 conversation.mtime.format("%F %I:%M%p").to_string(),
516 style.saved_at.text.clone(),
517 )
518 .aligned()
519 .contained()
520 .with_style(style.saved_at.container),
521 )
522 .with_child(
523 Label::new(conversation.title.clone(), style.title.text.clone())
524 .aligned()
525 .contained()
526 .with_style(style.title.container),
527 )
528 .contained()
529 .with_style(*style.container.style_for(state))
530 })
531 .with_cursor_style(CursorStyle::PointingHand)
532 .on_click(MouseButton::Left, move |_, this, cx| {
533 this.open_conversation(path.clone(), cx)
534 .detach_and_log_err(cx)
535 })
536 }
537
    /// Opens the saved conversation stored at `path`.
    ///
    /// If an editor for `path` is already open it is simply activated.
    /// Otherwise the file is read and parsed as JSON in a spawned task, a
    /// `Conversation` model is built from it, and a new editor is added. The
    /// returned task fails if the file cannot be read or parsed, or if the
    /// panel is dropped before the load completes.
    fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
        if let Some(ix) = self.editor_index_for_path(&path, cx) {
            self.set_active_editor_index(Some(ix), cx);
            return Task::ready(Ok(()));
        }

        let fs = self.fs.clone();
        let api_key = self.api_key.clone();
        let languages = self.languages.clone();
        cx.spawn(|this, mut cx| async move {
            let saved_conversation = fs.load(&path).await?;
            let saved_conversation = serde_json::from_str(&saved_conversation)?;
            let conversation = cx.add_model(|cx| {
                Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
            });
            this.update(&mut cx, |this, cx| {
                // If, by the time we've loaded the conversation, the user has already opened
                // the same conversation, we don't want to open it again.
                if let Some(ix) = this.editor_index_for_path(&path, cx) {
                    this.set_active_editor_index(Some(ix), cx);
                } else {
                    let editor = cx
                        .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
                    this.add_conversation(editor, cx);
                }
            })?;
            Ok(())
        })
    }
567
568 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
569 self.editors
570 .iter()
571 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
572 }
573}
574
575fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
576 cx.add_view(|cx| {
577 let mut editor = Editor::single_line(
578 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
579 cx,
580 );
581 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
582 editor
583 })
584}
585
// The panel communicates with its dock through `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
589
590impl View for AssistantPanel {
591 fn ui_name() -> &'static str {
592 "AssistantPanel"
593 }
594
595 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
596 let theme = &theme::current(cx);
597 let style = &theme.assistant;
598 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
599 Flex::column()
600 .with_child(
601 Text::new(
602 "Paste your OpenAI API key and press Enter to use the assistant",
603 style.api_key_prompt.text.clone(),
604 )
605 .aligned(),
606 )
607 .with_child(
608 ChildView::new(api_key_editor, cx)
609 .contained()
610 .with_style(style.api_key_editor.container)
611 .aligned(),
612 )
613 .contained()
614 .with_style(style.api_key_prompt.container)
615 .aligned()
616 .into_any()
617 } else {
618 let title = self.active_editor().map(|editor| {
619 Label::new(editor.read(cx).title(cx), style.title.text.clone())
620 .contained()
621 .with_style(style.title.container)
622 .aligned()
623 .left()
624 .flex(1., false)
625 });
626 let mut header = Flex::row()
627 .with_child(Self::render_hamburger_button(cx).aligned())
628 .with_children(title);
629 if self.has_focus {
630 header.add_children(
631 self.render_editor_tools(cx)
632 .into_iter()
633 .map(|tool| tool.aligned().flex_float()),
634 );
635 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
636 header.add_child(self.render_zoom_button(cx).aligned());
637 }
638
639 Flex::column()
640 .with_child(
641 header
642 .contained()
643 .with_style(theme.workspace.tab_bar.container)
644 .expanded()
645 .constrained()
646 .with_height(theme.workspace.tab_bar.height),
647 )
648 .with_children(if self.toolbar.read(cx).hidden() {
649 None
650 } else {
651 Some(ChildView::new(&self.toolbar, cx).expanded())
652 })
653 .with_child(if let Some(editor) = self.active_editor() {
654 ChildView::new(editor, cx).flex(1., true).into_any()
655 } else {
656 UniformList::new(
657 self.saved_conversations_list_state.clone(),
658 self.saved_conversations.len(),
659 cx,
660 |this, range, items, cx| {
661 for ix in range {
662 items.push(this.render_saved_conversation(ix, cx).into_any());
663 }
664 },
665 )
666 .flex(1., true)
667 .into_any()
668 })
669 .into_any()
670 }
671 }
672
673 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
674 self.has_focus = true;
675 self.toolbar
676 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
677 cx.notify();
678 if cx.is_self_focused() {
679 if let Some(editor) = self.active_editor() {
680 cx.focus(editor);
681 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
682 cx.focus(api_key_editor);
683 }
684 }
685 }
686
687 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
688 self.has_focus = false;
689 self.toolbar
690 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
691 cx.notify();
692 }
693}
694
695impl Panel for AssistantPanel {
696 fn position(&self, cx: &WindowContext) -> DockPosition {
697 match settings::get::<AssistantSettings>(cx).dock {
698 AssistantDockPosition::Left => DockPosition::Left,
699 AssistantDockPosition::Bottom => DockPosition::Bottom,
700 AssistantDockPosition::Right => DockPosition::Right,
701 }
702 }
703
704 fn position_is_valid(&self, _: DockPosition) -> bool {
705 true
706 }
707
708 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
709 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
710 let dock = match position {
711 DockPosition::Left => AssistantDockPosition::Left,
712 DockPosition::Bottom => AssistantDockPosition::Bottom,
713 DockPosition::Right => AssistantDockPosition::Right,
714 };
715 settings.dock = Some(dock);
716 });
717 }
718
719 fn size(&self, cx: &WindowContext) -> f32 {
720 let settings = settings::get::<AssistantSettings>(cx);
721 match self.position(cx) {
722 DockPosition::Left | DockPosition::Right => {
723 self.width.unwrap_or_else(|| settings.default_width)
724 }
725 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
726 }
727 }
728
729 fn set_size(&mut self, size: Option<f32>, cx: &mut ViewContext<Self>) {
730 match self.position(cx) {
731 DockPosition::Left | DockPosition::Right => self.width = size,
732 DockPosition::Bottom => self.height = size,
733 }
734 cx.notify();
735 }
736
737 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
738 matches!(event, AssistantPanelEvent::ZoomIn)
739 }
740
741 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
742 matches!(event, AssistantPanelEvent::ZoomOut)
743 }
744
745 fn is_zoomed(&self, _: &WindowContext) -> bool {
746 self.zoomed
747 }
748
749 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
750 self.zoomed = zoomed;
751 cx.notify();
752 }
753
754 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
755 if active {
756 if self.api_key.borrow().is_none() && !self.has_read_credentials {
757 self.has_read_credentials = true;
758 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
759 Some(api_key)
760 } else if let Some((_, api_key)) = cx
761 .platform()
762 .read_credentials(OPENAI_API_URL)
763 .log_err()
764 .flatten()
765 {
766 String::from_utf8(api_key).log_err()
767 } else {
768 None
769 };
770 if let Some(api_key) = api_key {
771 *self.api_key.borrow_mut() = Some(api_key);
772 } else if self.api_key_editor.is_none() {
773 self.api_key_editor = Some(build_api_key_editor(cx));
774 cx.notify();
775 }
776 }
777
778 if self.editors.is_empty() {
779 self.new_conversation(cx);
780 }
781 }
782 }
783
784 fn icon_path(&self, cx: &WindowContext) -> Option<&'static str> {
785 settings::get::<AssistantSettings>(cx)
786 .button
787 .then(|| "icons/ai.svg")
788 }
789
790 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
791 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
792 }
793
794 fn should_change_position_on_event(event: &Self::Event) -> bool {
795 matches!(event, AssistantPanelEvent::DockPositionChanged)
796 }
797
798 fn should_activate_on_event(_: &Self::Event) -> bool {
799 false
800 }
801
802 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
803 matches!(event, AssistantPanelEvent::Close)
804 }
805
806 fn has_focus(&self, _: &WindowContext) -> bool {
807 self.has_focus
808 }
809
810 fn is_focus_event(event: &Self::Event) -> bool {
811 matches!(event, AssistantPanelEvent::Focus)
812 }
813}
814
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// The underlying buffer was edited.
    MessagesEdited,
    /// Presumably emitted when the generated summary text changes (emitted outside this view).
    SummaryChanged,
    /// Presumably emitted as completion output is streamed in (emitted outside this view).
    StreamedCompletion,
}
820
/// Summary text of a conversation.
#[derive(Default)]
struct Summary {
    text: String,
    /// True once the summary is final; summaries restored from disk are
    /// marked done (presumably false while still being generated).
    done: bool,
}
826
/// Model backing one assistant conversation: a single Markdown buffer split
/// into messages by anchors, plus per-message metadata and token accounting.
struct Conversation {
    /// Buffer holding the text of every message.
    buffer: ModelHandle<Buffer>,
    /// Start anchor of each message (presumably in document order).
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time and status of each message.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id handed to the next message created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    model: OpenAIModel,
    /// Tokens used by the messages; `None` until the first count completes.
    token_count: Option<usize>,
    /// Context-window size of `model`.
    max_token_count: usize,
    /// Debounced token-counting task; replacing it supersedes an older count.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the owning panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    /// File this conversation was loaded from; `None` for conversations built via `new`.
    path: Option<PathBuf>,
    /// Keeps the buffer-event subscription alive.
    _subscriptions: Vec<Subscription>,
}
845
// Conversations notify observers through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
849
850impl Conversation {
851 fn new(
852 api_key: Rc<RefCell<Option<String>>>,
853 language_registry: Arc<LanguageRegistry>,
854 cx: &mut ModelContext<Self>,
855 ) -> Self {
856 let markdown = language_registry.language_for_name("Markdown");
857 let buffer = cx.add_model(|cx| {
858 let mut buffer = Buffer::new(0, cx.model_id() as u64, "");
859 buffer.set_language_registry(language_registry);
860 cx.spawn_weak(|buffer, mut cx| async move {
861 let markdown = markdown.await?;
862 let buffer = buffer
863 .upgrade(&cx)
864 .ok_or_else(|| anyhow!("buffer was dropped"))?;
865 buffer.update(&mut cx, |buffer: &mut Buffer, cx| {
866 buffer.set_language(Some(markdown), cx)
867 });
868 anyhow::Ok(())
869 })
870 .detach_and_log_err(cx);
871 buffer
872 });
873
874 let settings = settings::get::<AssistantSettings>(cx);
875 let model = settings.default_open_ai_model.clone();
876
877 let mut this = Self {
878 message_anchors: Default::default(),
879 messages_metadata: Default::default(),
880 next_message_id: Default::default(),
881 summary: None,
882 pending_summary: Task::ready(None),
883 completion_count: Default::default(),
884 pending_completions: Default::default(),
885 token_count: None,
886 max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
887 pending_token_count: Task::ready(None),
888 model: model.clone(),
889 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
890 pending_save: Task::ready(Ok(())),
891 path: None,
892 api_key,
893 buffer,
894 };
895 let message = MessageAnchor {
896 id: MessageId(post_inc(&mut this.next_message_id.0)),
897 start: language::Anchor::MIN,
898 };
899 this.message_anchors.push(message.clone());
900 this.messages_metadata.insert(
901 message.id,
902 MessageMetadata {
903 role: Role::User,
904 sent_at: Local::now(),
905 status: MessageStatus::Done,
906 },
907 );
908
909 this.count_remaining_tokens(cx);
910 this
911 }
912
913 fn serialize(&self, cx: &AppContext) -> SavedConversation {
914 SavedConversation {
915 zed: "conversation".into(),
916 version: SavedConversation::VERSION.into(),
917 text: self.buffer.read(cx).text(),
918 message_metadata: self.messages_metadata.clone(),
919 messages: self
920 .messages(cx)
921 .map(|message| SavedMessage {
922 id: message.id,
923 start: message.offset_range.start,
924 })
925 .collect(),
926 summary: self
927 .summary
928 .as_ref()
929 .map(|summary| summary.text.clone())
930 .unwrap_or_default(),
931 model: self.model.clone(),
932 }
933 }
934
    /// Rebuilds a conversation from its saved form, loaded from `path`.
    ///
    /// Restores the buffer text and re-creates each message's start anchor at
    /// its saved offset. `next_message_id` is set to one past the largest
    /// saved id, and the saved summary is marked done. Markdown is applied to
    /// the buffer asynchronously, and a token count is scheduled.
    fn deserialize(
        saved_conversation: SavedConversation,
        path: PathBuf,
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let model = saved_conversation.model;
        let markdown = language_registry.language_for_name("Markdown");
        let mut message_anchors = Vec::new();
        let mut next_message_id = MessageId(0);
        let buffer = cx.add_model(|cx| {
            let mut buffer = Buffer::new(0, cx.model_id() as u64, saved_conversation.text);
            // Anchors must be created against the freshly populated buffer.
            for message in saved_conversation.messages {
                message_anchors.push(MessageAnchor {
                    id: message.id,
                    start: buffer.anchor_before(message.start),
                });
                next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
            }
            buffer.set_language_registry(language_registry);
            // Apply Markdown once loaded; errors (including a dropped buffer) are logged.
            cx.spawn_weak(|buffer, mut cx| async move {
                let markdown = markdown.await?;
                let buffer = buffer
                    .upgrade(&cx)
                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
                buffer.update(&mut cx, |buffer: &mut Buffer, cx| {
                    buffer.set_language(Some(markdown), cx)
                });
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
            buffer
        });

        let mut this = Self {
            message_anchors,
            messages_metadata: saved_conversation.message_metadata,
            next_message_id,
            summary: Some(Summary {
                text: saved_conversation.summary,
                done: true,
            }),
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
            pending_token_count: Task::ready(None),
            model,
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            pending_save: Task::ready(Ok(())),
            path: Some(path),
            api_key,
            buffer,
        };
        this.count_remaining_tokens(cx);
        this
    }
994
995 fn handle_buffer_event(
996 &mut self,
997 _: ModelHandle<Buffer>,
998 event: &language::Event,
999 cx: &mut ModelContext<Self>,
1000 ) {
1001 match event {
1002 language::Event::Edited => {
1003 self.count_remaining_tokens(cx);
1004 cx.emit(ConversationEvent::MessagesEdited);
1005 }
1006 _ => {}
1007 }
1008 }
1009
1010 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1011 let messages = self
1012 .messages(cx)
1013 .into_iter()
1014 .filter_map(|message| {
1015 Some(tiktoken_rs::ChatCompletionRequestMessage {
1016 role: match message.role {
1017 Role::User => "user".into(),
1018 Role::Assistant => "assistant".into(),
1019 Role::System => "system".into(),
1020 },
1021 content: self
1022 .buffer
1023 .read(cx)
1024 .text_for_range(message.offset_range)
1025 .collect(),
1026 name: None,
1027 })
1028 })
1029 .collect::<Vec<_>>();
1030 let model = self.model.clone();
1031 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1032 async move {
1033 cx.background().timer(Duration::from_millis(200)).await;
1034 let token_count = cx
1035 .background()
1036 .spawn(async move {
1037 tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
1038 })
1039 .await?;
1040
1041 this.upgrade(&cx)
1042 .ok_or_else(|| anyhow!("conversation was dropped"))?
1043 .update(&mut cx, |this, cx| {
1044 this.max_token_count =
1045 tiktoken_rs::model::get_context_size(&this.model.full_name());
1046 this.token_count = Some(token_count);
1047 cx.notify()
1048 });
1049 anyhow::Ok(())
1050 }
1051 .log_err()
1052 });
1053 }
1054
1055 fn remaining_tokens(&self) -> Option<isize> {
1056 Some(self.max_token_count as isize - self.token_count? as isize)
1057 }
1058
1059 fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
1060 self.model = model;
1061 self.count_remaining_tokens(cx);
1062 cx.notify();
1063 }
1064
1065 fn assist(
1066 &mut self,
1067 selected_messages: HashSet<MessageId>,
1068 cx: &mut ModelContext<Self>,
1069 ) -> Vec<MessageAnchor> {
1070 let mut user_messages = Vec::new();
1071 let mut tasks = Vec::new();
1072
1073 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1074 message
1075 .start
1076 .is_valid(self.buffer.read(cx))
1077 .then_some(message.id)
1078 });
1079
1080 for selected_message_id in selected_messages {
1081 let selected_message_role =
1082 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1083 metadata.role
1084 } else {
1085 continue;
1086 };
1087
1088 if selected_message_role == Role::Assistant {
1089 if let Some(user_message) = self.insert_message_after(
1090 selected_message_id,
1091 Role::User,
1092 MessageStatus::Done,
1093 cx,
1094 ) {
1095 user_messages.push(user_message);
1096 } else {
1097 continue;
1098 }
1099 } else {
1100 let request = OpenAIRequest {
1101 model: self.model.full_name().to_string(),
1102 messages: self
1103 .messages(cx)
1104 .filter(|message| matches!(message.status, MessageStatus::Done))
1105 .flat_map(|message| {
1106 let mut system_message = None;
1107 if message.id == selected_message_id {
1108 system_message = Some(RequestMessage {
1109 role: Role::System,
1110 content: concat!(
1111 "Treat the following messages as additional knowledge you have learned about, ",
1112 "but act as if they were not part of this conversation. That is, treat them ",
1113 "as if the user didn't see them and couldn't possibly inquire about them."
1114 ).into()
1115 });
1116 }
1117
1118 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1119 })
1120 .chain(Some(RequestMessage {
1121 role: Role::System,
1122 content: format!(
1123 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1124 selected_message_id.0
1125 ),
1126 }))
1127 .collect(),
1128 stream: true,
1129 };
1130
1131 let Some(api_key) = self.api_key.borrow().clone() else {
1132 continue;
1133 };
1134 let stream = stream_completion(api_key, cx.background().clone(), request);
1135 let assistant_message = self
1136 .insert_message_after(
1137 selected_message_id,
1138 Role::Assistant,
1139 MessageStatus::Pending,
1140 cx,
1141 )
1142 .unwrap();
1143
1144 // Queue up the user's next reply
1145 if Some(selected_message_id) == last_message_id {
1146 let user_message = self
1147 .insert_message_after(
1148 assistant_message.id,
1149 Role::User,
1150 MessageStatus::Done,
1151 cx,
1152 )
1153 .unwrap();
1154 user_messages.push(user_message);
1155 }
1156
1157 tasks.push(cx.spawn_weak({
1158 |this, mut cx| async move {
1159 let assistant_message_id = assistant_message.id;
1160 let stream_completion = async {
1161 let mut messages = stream.await?;
1162
1163 while let Some(message) = messages.next().await {
1164 let mut message = message?;
1165 if let Some(choice) = message.choices.pop() {
1166 this.upgrade(&cx)
1167 .ok_or_else(|| anyhow!("conversation was dropped"))?
1168 .update(&mut cx, |this, cx| {
1169 let text: Arc<str> = choice.delta.content?.into();
1170 let message_ix = this.message_anchors.iter().position(
1171 |message| message.id == assistant_message_id,
1172 )?;
1173 this.buffer.update(cx, |buffer, cx| {
1174 let offset = this.message_anchors[message_ix + 1..]
1175 .iter()
1176 .find(|message| message.start.is_valid(buffer))
1177 .map_or(buffer.len(), |message| {
1178 message
1179 .start
1180 .to_offset(buffer)
1181 .saturating_sub(1)
1182 });
1183 buffer.edit([(offset..offset, text)], None, cx);
1184 });
1185 cx.emit(ConversationEvent::StreamedCompletion);
1186
1187 Some(())
1188 });
1189 }
1190 smol::future::yield_now().await;
1191 }
1192
1193 this.upgrade(&cx)
1194 .ok_or_else(|| anyhow!("conversation was dropped"))?
1195 .update(&mut cx, |this, cx| {
1196 this.pending_completions.retain(|completion| {
1197 completion.id != this.completion_count
1198 });
1199 this.summarize(cx);
1200 });
1201
1202 anyhow::Ok(())
1203 };
1204
1205 let result = stream_completion.await;
1206 if let Some(this) = this.upgrade(&cx) {
1207 this.update(&mut cx, |this, cx| {
1208 if let Some(metadata) =
1209 this.messages_metadata.get_mut(&assistant_message.id)
1210 {
1211 match result {
1212 Ok(_) => {
1213 metadata.status = MessageStatus::Done;
1214 }
1215 Err(error) => {
1216 metadata.status = MessageStatus::Error(
1217 error.to_string().trim().into(),
1218 );
1219 }
1220 }
1221 cx.notify();
1222 }
1223 });
1224 }
1225 }
1226 }));
1227 }
1228 }
1229
1230 if !tasks.is_empty() {
1231 self.pending_completions.push(PendingCompletion {
1232 id: post_inc(&mut self.completion_count),
1233 _tasks: tasks,
1234 });
1235 }
1236
1237 user_messages
1238 }
1239
1240 fn cancel_last_assist(&mut self) -> bool {
1241 self.pending_completions.pop().is_some()
1242 }
1243
1244 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1245 for id in ids {
1246 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1247 metadata.role.cycle();
1248 cx.emit(ConversationEvent::MessagesEdited);
1249 cx.notify();
1250 }
1251 }
1252 }
1253
1254 fn insert_message_after(
1255 &mut self,
1256 message_id: MessageId,
1257 role: Role,
1258 status: MessageStatus,
1259 cx: &mut ModelContext<Self>,
1260 ) -> Option<MessageAnchor> {
1261 if let Some(prev_message_ix) = self
1262 .message_anchors
1263 .iter()
1264 .position(|message| message.id == message_id)
1265 {
1266 // Find the next valid message after the one we were given.
1267 let mut next_message_ix = prev_message_ix + 1;
1268 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1269 if next_message.start.is_valid(self.buffer.read(cx)) {
1270 break;
1271 }
1272 next_message_ix += 1;
1273 }
1274
1275 let start = self.buffer.update(cx, |buffer, cx| {
1276 let offset = self
1277 .message_anchors
1278 .get(next_message_ix)
1279 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1280 buffer.edit([(offset..offset, "\n")], None, cx);
1281 buffer.anchor_before(offset + 1)
1282 });
1283 let message = MessageAnchor {
1284 id: MessageId(post_inc(&mut self.next_message_id.0)),
1285 start,
1286 };
1287 self.message_anchors
1288 .insert(next_message_ix, message.clone());
1289 self.messages_metadata.insert(
1290 message.id,
1291 MessageMetadata {
1292 role,
1293 sent_at: Local::now(),
1294 status,
1295 },
1296 );
1297 cx.emit(ConversationEvent::MessagesEdited);
1298 Some(message)
1299 } else {
1300 None
1301 }
1302 }
1303
1304 fn split_message(
1305 &mut self,
1306 range: Range<usize>,
1307 cx: &mut ModelContext<Self>,
1308 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1309 let start_message = self.message_for_offset(range.start, cx);
1310 let end_message = self.message_for_offset(range.end, cx);
1311 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1312 // Prevent splitting when range spans multiple messages.
1313 if start_message.id != end_message.id {
1314 return (None, None);
1315 }
1316
1317 let message = start_message;
1318 let role = message.role;
1319 let mut edited_buffer = false;
1320
1321 let mut suffix_start = None;
1322 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1323 {
1324 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1325 suffix_start = Some(range.end + 1);
1326 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1327 suffix_start = Some(range.end);
1328 }
1329 }
1330
1331 let suffix = if let Some(suffix_start) = suffix_start {
1332 MessageAnchor {
1333 id: MessageId(post_inc(&mut self.next_message_id.0)),
1334 start: self.buffer.read(cx).anchor_before(suffix_start),
1335 }
1336 } else {
1337 self.buffer.update(cx, |buffer, cx| {
1338 buffer.edit([(range.end..range.end, "\n")], None, cx);
1339 });
1340 edited_buffer = true;
1341 MessageAnchor {
1342 id: MessageId(post_inc(&mut self.next_message_id.0)),
1343 start: self.buffer.read(cx).anchor_before(range.end + 1),
1344 }
1345 };
1346
1347 self.message_anchors
1348 .insert(message.index_range.end + 1, suffix.clone());
1349 self.messages_metadata.insert(
1350 suffix.id,
1351 MessageMetadata {
1352 role,
1353 sent_at: Local::now(),
1354 status: MessageStatus::Done,
1355 },
1356 );
1357
1358 let new_messages =
1359 if range.start == range.end || range.start == message.offset_range.start {
1360 (None, Some(suffix))
1361 } else {
1362 let mut prefix_end = None;
1363 if range.start > message.offset_range.start
1364 && range.end < message.offset_range.end - 1
1365 {
1366 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1367 prefix_end = Some(range.start + 1);
1368 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1369 == Some('\n')
1370 {
1371 prefix_end = Some(range.start);
1372 }
1373 }
1374
1375 let selection = if let Some(prefix_end) = prefix_end {
1376 cx.emit(ConversationEvent::MessagesEdited);
1377 MessageAnchor {
1378 id: MessageId(post_inc(&mut self.next_message_id.0)),
1379 start: self.buffer.read(cx).anchor_before(prefix_end),
1380 }
1381 } else {
1382 self.buffer.update(cx, |buffer, cx| {
1383 buffer.edit([(range.start..range.start, "\n")], None, cx)
1384 });
1385 edited_buffer = true;
1386 MessageAnchor {
1387 id: MessageId(post_inc(&mut self.next_message_id.0)),
1388 start: self.buffer.read(cx).anchor_before(range.end + 1),
1389 }
1390 };
1391
1392 self.message_anchors
1393 .insert(message.index_range.end + 1, selection.clone());
1394 self.messages_metadata.insert(
1395 selection.id,
1396 MessageMetadata {
1397 role,
1398 sent_at: Local::now(),
1399 status: MessageStatus::Done,
1400 },
1401 );
1402 (Some(selection), Some(suffix))
1403 };
1404
1405 if !edited_buffer {
1406 cx.emit(ConversationEvent::MessagesEdited);
1407 }
1408 new_messages
1409 } else {
1410 (None, None)
1411 }
1412 }
1413
1414 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1415 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1416 let api_key = self.api_key.borrow().clone();
1417 if let Some(api_key) = api_key {
1418 let messages = self
1419 .messages(cx)
1420 .take(2)
1421 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1422 .chain(Some(RequestMessage {
1423 role: Role::User,
1424 content:
1425 "Summarize the conversation into a short title without punctuation"
1426 .into(),
1427 }));
1428 let request = OpenAIRequest {
1429 model: self.model.full_name().to_string(),
1430 messages: messages.collect(),
1431 stream: true,
1432 };
1433
1434 let stream = stream_completion(api_key, cx.background().clone(), request);
1435 self.pending_summary = cx.spawn(|this, mut cx| {
1436 async move {
1437 let mut messages = stream.await?;
1438
1439 while let Some(message) = messages.next().await {
1440 let mut message = message?;
1441 if let Some(choice) = message.choices.pop() {
1442 let text = choice.delta.content.unwrap_or_default();
1443 this.update(&mut cx, |this, cx| {
1444 this.summary
1445 .get_or_insert(Default::default())
1446 .text
1447 .push_str(&text);
1448 cx.emit(ConversationEvent::SummaryChanged);
1449 });
1450 }
1451 }
1452
1453 this.update(&mut cx, |this, cx| {
1454 if let Some(summary) = this.summary.as_mut() {
1455 summary.done = true;
1456 cx.emit(ConversationEvent::SummaryChanged);
1457 }
1458 });
1459
1460 anyhow::Ok(())
1461 }
1462 .log_err()
1463 });
1464 }
1465 }
1466 }
1467
1468 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1469 self.messages_for_offsets([offset], cx).pop()
1470 }
1471
    /// Maps offsets to the messages that contain them.
    ///
    /// Each run of consecutive offsets falling in the same message yields that
    /// message once.
    ///
    /// Both `offsets` and the messages are walked forward only, so `offsets`
    /// presumably must be sorted ascending. Callers pass selection heads — TODO
    /// confirm they are always sorted.
    ///
    /// An offset that no remaining message contains maps to the last message.
    /// Once the last message is reached, all remaining offsets collapse into it.
    /// Returns an empty vector when the conversation has no messages.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Never advances past the last message, even if it doesn't contain
            // `offset`.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            // `current_message` is only `None` when there were no messages at all.
            let Some(message) = current_message.as_ref() else {
                break;
            };

            // Skip offsets that are in the same message.
            // On the last message every remaining offset maps to it as well.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1504
    /// Lazily iterates over the conversation's messages in `message_anchors` order.
    ///
    /// Each message starts at its anchor. It ends at the start of the next anchor
    /// that `is_valid` in the buffer, or at the end of the buffer.
    ///
    /// Anchors that are not valid are folded into the preceding message. The
    /// returned `index_range.end` is the index of the last anchor consumed, which
    /// is inclusive and equals `index_range.start` when nothing was folded. The
    /// starting anchor itself is not validity-checked.
    ///
    /// Iteration stops entirely at the first anchor without metadata.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            // The body always returns, so this loop runs at most once per call.
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                // `?` ends the whole iterator (not just this item) on missing metadata.
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume invalid anchors until the next valid one, which marks
                // this message's end (but is left for the next call).
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1539
1540 fn save(
1541 &mut self,
1542 debounce: Option<Duration>,
1543 fs: Arc<dyn Fs>,
1544 cx: &mut ModelContext<Conversation>,
1545 ) {
1546 self.pending_save = cx.spawn(|this, mut cx| async move {
1547 if let Some(debounce) = debounce {
1548 cx.background().timer(debounce).await;
1549 }
1550
1551 let (old_path, summary) = this.read_with(&cx, |this, _| {
1552 let path = this.path.clone();
1553 let summary = if let Some(summary) = this.summary.as_ref() {
1554 if summary.done {
1555 Some(summary.text.clone())
1556 } else {
1557 None
1558 }
1559 } else {
1560 None
1561 };
1562 (path, summary)
1563 });
1564
1565 if let Some(summary) = summary {
1566 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1567 let path = if let Some(old_path) = old_path {
1568 old_path
1569 } else {
1570 let mut discriminant = 1;
1571 let mut new_path;
1572 loop {
1573 new_path = CONVERSATIONS_DIR.join(&format!(
1574 "{} - {}.zed.json",
1575 summary.trim(),
1576 discriminant
1577 ));
1578 if fs.is_file(&new_path).await {
1579 discriminant += 1;
1580 } else {
1581 break;
1582 }
1583 }
1584 new_path
1585 };
1586
1587 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1588 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1589 .await?;
1590 this.update(&mut cx, |this, _| this.path = Some(path));
1591 }
1592
1593 Ok(())
1594 });
1595 }
1596}
1597
1598struct PendingCompletion {
1599 id: usize,
1600 _tasks: Vec<Task<()>>,
1601}
1602
/// Events a `ConversationEditor` emits to its subscribers.
enum ConversationEditorEvent {
    /// The conversation summary, which serves as the tab title, changed.
    TabContentChanged,
}
1606
1607#[derive(Copy, Clone, Debug, PartialEq)]
1608struct ScrollPosition {
1609 offset_before_cursor: Vector2F,
1610 cursor: Anchor,
1611}
1612
1613struct ConversationEditor {
1614 conversation: ModelHandle<Conversation>,
1615 fs: Arc<dyn Fs>,
1616 editor: ViewHandle<Editor>,
1617 blocks: HashSet<BlockId>,
1618 scroll_position: Option<ScrollPosition>,
1619 _subscriptions: Vec<Subscription>,
1620}
1621
1622impl ConversationEditor {
1623 fn new(
1624 api_key: Rc<RefCell<Option<String>>>,
1625 language_registry: Arc<LanguageRegistry>,
1626 fs: Arc<dyn Fs>,
1627 cx: &mut ViewContext<Self>,
1628 ) -> Self {
1629 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1630 Self::for_conversation(conversation, fs, cx)
1631 }
1632
1633 fn for_conversation(
1634 conversation: ModelHandle<Conversation>,
1635 fs: Arc<dyn Fs>,
1636 cx: &mut ViewContext<Self>,
1637 ) -> Self {
1638 let editor = cx.add_view(|cx| {
1639 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1640 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1641 editor.set_show_gutter(false, cx);
1642 editor.set_show_wrap_guides(false, cx);
1643 editor
1644 });
1645
1646 let _subscriptions = vec![
1647 cx.observe(&conversation, |_, _, cx| cx.notify()),
1648 cx.subscribe(&conversation, Self::handle_conversation_event),
1649 cx.subscribe(&editor, Self::handle_editor_event),
1650 ];
1651
1652 let mut this = Self {
1653 conversation,
1654 editor,
1655 blocks: Default::default(),
1656 scroll_position: None,
1657 fs,
1658 _subscriptions,
1659 };
1660 this.update_message_headers(cx);
1661 this
1662 }
1663
1664 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1665 let cursors = self.cursors(cx);
1666
1667 let user_messages = self.conversation.update(cx, |conversation, cx| {
1668 let selected_messages = conversation
1669 .messages_for_offsets(cursors, cx)
1670 .into_iter()
1671 .map(|message| message.id)
1672 .collect();
1673 conversation.assist(selected_messages, cx)
1674 });
1675 let new_selections = user_messages
1676 .iter()
1677 .map(|message| {
1678 let cursor = message
1679 .start
1680 .to_offset(self.conversation.read(cx).buffer.read(cx));
1681 cursor..cursor
1682 })
1683 .collect::<Vec<_>>();
1684 if !new_selections.is_empty() {
1685 self.editor.update(cx, |editor, cx| {
1686 editor.change_selections(
1687 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1688 cx,
1689 |selections| selections.select_ranges(new_selections),
1690 );
1691 });
1692 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1693 cx.defer(|this, _| this.scroll_position = None);
1694 }
1695 }
1696
1697 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1698 if !self
1699 .conversation
1700 .update(cx, |conversation, _| conversation.cancel_last_assist())
1701 {
1702 cx.propagate_action();
1703 }
1704 }
1705
1706 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1707 let cursors = self.cursors(cx);
1708 self.conversation.update(cx, |conversation, cx| {
1709 let messages = conversation
1710 .messages_for_offsets(cursors, cx)
1711 .into_iter()
1712 .map(|message| message.id)
1713 .collect();
1714 conversation.cycle_message_roles(messages, cx)
1715 });
1716 }
1717
1718 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1719 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1720 selections
1721 .into_iter()
1722 .map(|selection| selection.head())
1723 .collect()
1724 }
1725
1726 fn handle_conversation_event(
1727 &mut self,
1728 _: ModelHandle<Conversation>,
1729 event: &ConversationEvent,
1730 cx: &mut ViewContext<Self>,
1731 ) {
1732 match event {
1733 ConversationEvent::MessagesEdited => {
1734 self.update_message_headers(cx);
1735 self.conversation.update(cx, |conversation, cx| {
1736 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1737 });
1738 }
1739 ConversationEvent::SummaryChanged => {
1740 cx.emit(ConversationEditorEvent::TabContentChanged);
1741 self.conversation.update(cx, |conversation, cx| {
1742 conversation.save(None, self.fs.clone(), cx);
1743 });
1744 }
1745 ConversationEvent::StreamedCompletion => {
1746 self.editor.update(cx, |editor, cx| {
1747 if let Some(scroll_position) = self.scroll_position {
1748 let snapshot = editor.snapshot(cx);
1749 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1750 let scroll_top =
1751 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1752 editor.set_scroll_position(
1753 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1754 cx,
1755 );
1756 }
1757 });
1758 }
1759 }
1760 }
1761
1762 fn handle_editor_event(
1763 &mut self,
1764 _: ViewHandle<Editor>,
1765 event: &editor::Event,
1766 cx: &mut ViewContext<Self>,
1767 ) {
1768 match event {
1769 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1770 let cursor_scroll_position = self.cursor_scroll_position(cx);
1771 if *autoscroll {
1772 self.scroll_position = cursor_scroll_position;
1773 } else if self.scroll_position != cursor_scroll_position {
1774 self.scroll_position = None;
1775 }
1776 }
1777 editor::Event::SelectionsChanged { .. } => {
1778 self.scroll_position = self.cursor_scroll_position(cx);
1779 }
1780 _ => {}
1781 }
1782 }
1783
1784 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1785 self.editor.update(cx, |editor, cx| {
1786 let snapshot = editor.snapshot(cx);
1787 let cursor = editor.selections.newest_anchor().head();
1788 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1789 let scroll_position = editor
1790 .scroll_manager
1791 .anchor()
1792 .scroll_position(&snapshot.display_snapshot);
1793
1794 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1795 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1796 Some(ScrollPosition {
1797 cursor,
1798 offset_before_cursor: vec2f(
1799 scroll_position.x(),
1800 cursor_row - scroll_position.y(),
1801 ),
1802 })
1803 } else {
1804 None
1805 }
1806 })
1807 }
1808
1809 fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1810 self.editor.update(cx, |editor, cx| {
1811 let buffer = editor.buffer().read(cx).snapshot(cx);
1812 let excerpt_id = *buffer.as_singleton().unwrap().0;
1813 let old_blocks = std::mem::take(&mut self.blocks);
1814 let new_blocks = self
1815 .conversation
1816 .read(cx)
1817 .messages(cx)
1818 .map(|message| BlockProperties {
1819 position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1820 height: 2,
1821 style: BlockStyle::Sticky,
1822 render: Arc::new({
1823 let conversation = self.conversation.clone();
1824 // let metadata = message.metadata.clone();
1825 // let message = message.clone();
1826 move |cx| {
1827 enum Sender {}
1828 enum ErrorTooltip {}
1829
1830 let theme = theme::current(cx);
1831 let style = &theme.assistant;
1832 let message_id = message.id;
1833 let sender = MouseEventHandler::new::<Sender, _>(
1834 message_id.0,
1835 cx,
1836 |state, _| match message.role {
1837 Role::User => {
1838 let style = style.user_sender.style_for(state);
1839 Label::new("You", style.text.clone())
1840 .contained()
1841 .with_style(style.container)
1842 }
1843 Role::Assistant => {
1844 let style = style.assistant_sender.style_for(state);
1845 Label::new("Assistant", style.text.clone())
1846 .contained()
1847 .with_style(style.container)
1848 }
1849 Role::System => {
1850 let style = style.system_sender.style_for(state);
1851 Label::new("System", style.text.clone())
1852 .contained()
1853 .with_style(style.container)
1854 }
1855 },
1856 )
1857 .with_cursor_style(CursorStyle::PointingHand)
1858 .on_down(MouseButton::Left, {
1859 let conversation = conversation.clone();
1860 move |_, _, cx| {
1861 conversation.update(cx, |conversation, cx| {
1862 conversation.cycle_message_roles(
1863 HashSet::from_iter(Some(message_id)),
1864 cx,
1865 )
1866 })
1867 }
1868 });
1869
1870 Flex::row()
1871 .with_child(sender.aligned())
1872 .with_child(
1873 Label::new(
1874 message.sent_at.format("%I:%M%P").to_string(),
1875 style.sent_at.text.clone(),
1876 )
1877 .contained()
1878 .with_style(style.sent_at.container)
1879 .aligned(),
1880 )
1881 .with_children(
1882 if let MessageStatus::Error(error) = &message.status {
1883 Some(
1884 Svg::new("icons/circle_x_mark_12.svg")
1885 .with_color(style.error_icon.color)
1886 .constrained()
1887 .with_width(style.error_icon.width)
1888 .contained()
1889 .with_style(style.error_icon.container)
1890 .with_tooltip::<ErrorTooltip>(
1891 message_id.0,
1892 error.to_string(),
1893 None,
1894 theme.tooltip.clone(),
1895 cx,
1896 )
1897 .aligned(),
1898 )
1899 } else {
1900 None
1901 },
1902 )
1903 .aligned()
1904 .left()
1905 .contained()
1906 .with_style(style.message_header)
1907 .into_any()
1908 }
1909 }),
1910 disposition: BlockDisposition::Above,
1911 })
1912 .collect::<Vec<_>>();
1913
1914 editor.remove_blocks(old_blocks, None, cx);
1915 let ids = editor.insert_blocks(new_blocks, None, cx);
1916 self.blocks = HashSet::from_iter(ids);
1917 });
1918 }
1919
1920 fn quote_selection(
1921 workspace: &mut Workspace,
1922 _: &QuoteSelection,
1923 cx: &mut ViewContext<Workspace>,
1924 ) {
1925 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1926 return;
1927 };
1928 let Some(editor) = workspace
1929 .active_item(cx)
1930 .and_then(|item| item.act_as::<Editor>(cx))
1931 else {
1932 return;
1933 };
1934
1935 let text = editor.read_with(cx, |editor, cx| {
1936 let range = editor.selections.newest::<usize>(cx).range();
1937 let buffer = editor.buffer().read(cx).snapshot(cx);
1938 let start_language = buffer.language_at(range.start);
1939 let end_language = buffer.language_at(range.end);
1940 let language_name = if start_language == end_language {
1941 start_language.map(|language| language.name())
1942 } else {
1943 None
1944 };
1945 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1946
1947 let selected_text = buffer.text_for_range(range).collect::<String>();
1948 if selected_text.is_empty() {
1949 None
1950 } else {
1951 Some(if language_name == "markdown" {
1952 selected_text
1953 .lines()
1954 .map(|line| format!("> {}", line))
1955 .collect::<Vec<_>>()
1956 .join("\n")
1957 } else {
1958 format!("```{language_name}\n{selected_text}\n```")
1959 })
1960 }
1961 });
1962
1963 // Activate the panel
1964 if !panel.read(cx).has_focus(cx) {
1965 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1966 }
1967
1968 if let Some(text) = text {
1969 panel.update(cx, |panel, cx| {
1970 let conversation = panel
1971 .active_editor()
1972 .cloned()
1973 .unwrap_or_else(|| panel.new_conversation(cx));
1974 conversation.update(cx, |conversation, cx| {
1975 conversation
1976 .editor
1977 .update(cx, |editor, cx| editor.insert(&text, cx))
1978 });
1979 });
1980 }
1981 }
1982
1983 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1984 let editor = self.editor.read(cx);
1985 let conversation = self.conversation.read(cx);
1986 if editor.selections.count() == 1 {
1987 let selection = editor.selections.newest::<usize>(cx);
1988 let mut copied_text = String::new();
1989 let mut spanned_messages = 0;
1990 for message in conversation.messages(cx) {
1991 if message.offset_range.start >= selection.range().end {
1992 break;
1993 } else if message.offset_range.end >= selection.range().start {
1994 let range = cmp::max(message.offset_range.start, selection.range().start)
1995 ..cmp::min(message.offset_range.end, selection.range().end);
1996 if !range.is_empty() {
1997 spanned_messages += 1;
1998 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1999 for chunk in conversation.buffer.read(cx).text_for_range(range) {
2000 copied_text.push_str(&chunk);
2001 }
2002 copied_text.push('\n');
2003 }
2004 }
2005 }
2006
2007 if spanned_messages > 1 {
2008 cx.platform()
2009 .write_to_clipboard(ClipboardItem::new(copied_text));
2010 return;
2011 }
2012 }
2013
2014 cx.propagate_action();
2015 }
2016
2017 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2018 self.conversation.update(cx, |conversation, cx| {
2019 let selections = self.editor.read(cx).selections.disjoint_anchors();
2020 for selection in selections.into_iter() {
2021 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2022 let range = selection
2023 .map(|endpoint| endpoint.to_offset(&buffer))
2024 .range();
2025 conversation.split_message(range, cx);
2026 }
2027 });
2028 }
2029
2030 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2031 self.conversation.update(cx, |conversation, cx| {
2032 conversation.save(None, self.fs.clone(), cx)
2033 });
2034 }
2035
2036 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2037 self.conversation.update(cx, |conversation, cx| {
2038 let new_model = conversation.model.cycle();
2039 conversation.set_model(new_model, cx);
2040 });
2041 }
2042
2043 fn title(&self, cx: &AppContext) -> String {
2044 self.conversation
2045 .read(cx)
2046 .summary
2047 .as_ref()
2048 .map(|summary| summary.text.clone())
2049 .unwrap_or_else(|| "New Conversation".into())
2050 }
2051
2052 fn render_current_model(
2053 &self,
2054 style: &AssistantStyle,
2055 cx: &mut ViewContext<Self>,
2056 ) -> impl Element<Self> {
2057 enum Model {}
2058
2059 MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
2060 let style = style.model.style_for(state);
2061 let model_display_name = self.conversation.read(cx).model.short_name();
2062 Label::new(model_display_name, style.text.clone())
2063 .contained()
2064 .with_style(style.container)
2065 })
2066 .with_cursor_style(CursorStyle::PointingHand)
2067 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2068 }
2069
2070 fn render_remaining_tokens(
2071 &self,
2072 style: &AssistantStyle,
2073 cx: &mut ViewContext<Self>,
2074 ) -> Option<impl Element<Self>> {
2075 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2076 let remaining_tokens_style = if remaining_tokens <= 0 {
2077 &style.no_remaining_tokens
2078 } else if remaining_tokens <= 500 {
2079 &style.low_remaining_tokens
2080 } else {
2081 &style.remaining_tokens
2082 };
2083 Some(
2084 Label::new(
2085 remaining_tokens.to_string(),
2086 remaining_tokens_style.text.clone(),
2087 )
2088 .contained()
2089 .with_style(remaining_tokens_style.container),
2090 )
2091 }
2092}
2093
2094impl Entity for ConversationEditor {
2095 type Event = ConversationEditorEvent;
2096}
2097
2098impl View for ConversationEditor {
2099 fn ui_name() -> &'static str {
2100 "ConversationEditor"
2101 }
2102
2103 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2104 let theme = &theme::current(cx).assistant;
2105 Stack::new()
2106 .with_child(
2107 ChildView::new(&self.editor, cx)
2108 .contained()
2109 .with_style(theme.container),
2110 )
2111 .with_child(
2112 Flex::row()
2113 .with_child(self.render_current_model(theme, cx))
2114 .with_children(self.render_remaining_tokens(theme, cx))
2115 .aligned()
2116 .top()
2117 .right(),
2118 )
2119 .into_any()
2120 }
2121
2122 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2123 if cx.is_self_focused() {
2124 cx.focus(&self.editor);
2125 }
2126 }
2127}
2128
2129#[derive(Clone, Debug)]
2130struct MessageAnchor {
2131 id: MessageId,
2132 start: language::Anchor,
2133}
2134
2135#[derive(Clone, Debug)]
2136pub struct Message {
2137 offset_range: Range<usize>,
2138 index_range: Range<usize>,
2139 id: MessageId,
2140 anchor: language::Anchor,
2141 role: Role,
2142 sent_at: DateTime<Local>,
2143 status: MessageStatus,
2144}
2145
2146impl Message {
2147 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2148 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2149 content.extend(buffer.text_for_range(self.offset_range.clone()));
2150 RequestMessage {
2151 role: self.role,
2152 content: content.trim_end().into(),
2153 }
2154 }
2155}
2156
2157async fn stream_completion(
2158 api_key: String,
2159 executor: Arc<Background>,
2160 mut request: OpenAIRequest,
2161) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2162 request.stream = true;
2163
2164 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2165
2166 let json_data = serde_json::to_string(&request)?;
2167 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2168 .header("Content-Type", "application/json")
2169 .header("Authorization", format!("Bearer {}", api_key))
2170 .body(json_data)?
2171 .send_async()
2172 .await?;
2173
2174 let status = response.status();
2175 if status == StatusCode::OK {
2176 executor
2177 .spawn(async move {
2178 let mut lines = BufReader::new(response.body_mut()).lines();
2179
2180 fn parse_line(
2181 line: Result<String, io::Error>,
2182 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2183 if let Some(data) = line?.strip_prefix("data: ") {
2184 let event = serde_json::from_str(&data)?;
2185 Ok(Some(event))
2186 } else {
2187 Ok(None)
2188 }
2189 }
2190
2191 while let Some(line) = lines.next().await {
2192 if let Some(event) = parse_line(line).transpose() {
2193 let done = event.as_ref().map_or(false, |event| {
2194 event
2195 .choices
2196 .last()
2197 .map_or(false, |choice| choice.finish_reason.is_some())
2198 });
2199 if tx.unbounded_send(event).is_err() {
2200 break;
2201 }
2202
2203 if done {
2204 break;
2205 }
2206 }
2207 }
2208
2209 anyhow::Ok(())
2210 })
2211 .detach();
2212
2213 Ok(rx)
2214 } else {
2215 let mut body = String::new();
2216 response.body_mut().read_to_string(&mut body).await?;
2217
2218 #[derive(Deserialize)]
2219 struct OpenAIResponse {
2220 error: OpenAIError,
2221 }
2222
2223 #[derive(Deserialize)]
2224 struct OpenAIError {
2225 message: String,
2226 }
2227
2228 match serde_json::from_str::<OpenAIResponse>(&body) {
2229 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2230 "Failed to connect to OpenAI API: {}",
2231 response.error.message,
2232 )),
2233
2234 _ => Err(anyhow!(
2235 "Failed to connect to OpenAI API: {} {}",
2236 response.status(),
2237 body,
2238 )),
2239 }
2240 }
2241}
2242
#[cfg(test)]
mod tests {
    // Tests for how `Conversation` maps buffer ranges to messages as messages are
    // inserted, split, edited, merged via deletion, undone/redone, looked up by offset,
    // and round-tripped through serialization.
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation starts with a single empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 adds a newline separator: message_1 now spans 0..1 and
        // the new, empty message starts at offset 1.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text inserted at each message's start offset is attributed to that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message directly after
        // message_2, ahead of the previously inserted message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Splitting at an empty range: only the second returned message is used here.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range yields two new messages: one holding the selected
        // text ("dd") and one holding the text that followed it ("d").
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build three messages whose text is "aaa\n", "bbb\n" and "ccc".
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        // One offset inside each message yields each message once, in order.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Several offsets in the same message yield it only once, and the end-of-buffer
        // offset (11) maps to the last message.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message is still found at its (empty) start offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Ids of the messages `messages_for_offsets` returns for `offsets`, in returned order.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        init(cx);
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Finalizing keeps this text edit in its own transaction, so the undo below only
        // reverts message_3's insertion (per the assertions that follow).
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        // Round-trip through serialize/deserialize: buffer text and message layout survive.
        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    // Returns `(id, role, offset_range)` for every message, in the order
    // `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}