1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role, SavedConversation,
4 SavedConversationMetadata,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use serde::Deserialize;
28use settings::SettingsStore;
29use std::{
30 borrow::Cow, cell::RefCell, cmp, env, fmt::Write, io, iter, ops::Range, path::PathBuf, rc::Rc,
31 sync::Arc, time::Duration,
32};
33use util::{
34 channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt,
35 TryFutureExt,
36};
37use workspace::{
38 dock::{DockPosition, Panel},
39 item::Item,
40 pane, Pane, Save, Workspace,
41};
42
/// Base URL of the OpenAI API. It is also the key under which the API key is
/// written to, read from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
44
45actions!(
46 assistant,
47 [
48 NewContext,
49 Assist,
50 Split,
51 CycleMessageRole,
52 QuoteSelection,
53 ToggleFocus,
54 ResetKey,
55 ]
56);
57
58pub fn init(cx: &mut AppContext) {
59 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
60 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
61 filter.filtered_namespaces.insert("assistant");
62 });
63 }
64
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
68 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
69 this.update(cx, |this, cx| this.add_context(cx))
70 }
71
72 workspace.focus_panel::<AssistantPanel>(cx);
73 },
74 );
75 cx.add_action(AssistantEditor::assist);
76 cx.capture_action(AssistantEditor::cancel_last_assist);
77 cx.capture_action(AssistantEditor::save);
78 cx.add_action(AssistantEditor::quote_selection);
79 cx.capture_action(AssistantEditor::copy);
80 cx.capture_action(AssistantEditor::split);
81 cx.capture_action(AssistantEditor::cycle_message_role);
82 cx.add_action(AssistantPanel::save_api_key);
83 cx.add_action(AssistantPanel::reset_api_key);
84 cx.add_action(
85 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
86 workspace.toggle_panel_focus::<AssistantPanel>(cx);
87 },
88 );
89}
90
/// Events emitted by [`AssistantPanel`].
///
/// `ZoomIn`, `ZoomOut`, `Focus` and `Close` are re-emitted from the inner
/// pane's events; `DockPositionChanged` is emitted when a settings change
/// moves the panel's configured dock position.
pub enum AssistantPanelEvent {
    ZoomIn,
    ZoomOut,
    Focus,
    Close,
    DockPositionChanged,
}
98
/// Dockable panel hosting a pane of assistant editors ("contexts").
pub struct AssistantPanel {
    // User-resized width when docked left/right; `None` falls back to the settings default.
    width: Option<f32>,
    // User-resized height when docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    // Pane holding one `AssistantEditor` item per context.
    pane: ViewHandle<Pane>,
    // OpenAI API key shared with every `AssistantEditor` / `Assistant` created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    // Present while the panel is prompting the user to paste an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Set once `set_active` has tried the environment and the credential store.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    // Metadata of conversations found in `CONVERSATIONS_DIR`; refreshed by the watcher task.
    saved_conversations: Vec<SavedConversationMetadata>,
    // Watches `CONVERSATIONS_DIR`; held so the task lives as long as the panel.
    _watch_saved_conversations: Task<Result<()>>,
}
112
impl AssistantPanel {
    /// Builds the assistant panel for `workspace`.
    ///
    /// Lists the saved conversations using the workspace's file system, then
    /// creates the panel view, which consists of:
    /// - an inner `Pane` with splitting, navigation and drops disabled, "New Context" and
    ///   "Toggle Zoom" tab-bar buttons, and a buffer search bar;
    /// - a task that re-lists saved conversations whenever `CONVERSATIONS_DIR` changes;
    /// - subscriptions to pane events and to settings changes.
    ///
    /// Errors if reading or updating the weak workspace handle fails
    /// (presumably because the workspace was dropped).
    pub fn load(
        workspace: WeakViewHandle<Workspace>,
        cx: AsyncAppContext,
    ) -> Task<Result<ViewHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
            // A failed listing is logged and treated as "no saved conversations".
            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                .await
                .log_err()
                .unwrap_or_default();

            // TODO: deserialize state.
            workspace.update(&mut cx, |workspace, cx| {
                cx.add_view::<Self, _>(|cx| {
                    let weak_self = cx.weak_handle();
                    let pane = cx.add_view(|cx| {
                        let mut pane = Pane::new(
                            workspace.weak_handle(),
                            workspace.project().clone(),
                            workspace.app_state().background_actions,
                            Default::default(),
                            cx,
                        );
                        pane.set_can_split(false, cx);
                        pane.set_can_navigate(false, cx);
                        // Reject every drop onto this pane.
                        pane.on_can_drop(move |_, _| false);
                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
                            let weak_self = weak_self.clone();
                            Flex::row()
                                .with_child(Pane::render_tab_bar_button(
                                    0,
                                    "icons/plus_12.svg",
                                    false,
                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
                                    cx,
                                    move |_, cx| {
                                        let weak_self = weak_self.clone();
                                        // Deferred so the panel is updated outside the pane's
                                        // click handler; the panel may be gone by then.
                                        cx.window_context().defer(move |cx| {
                                            if let Some(this) = weak_self.upgrade(cx) {
                                                this.update(cx, |this, cx| this.add_context(cx));
                                            }
                                        })
                                    },
                                    None,
                                ))
                                .with_child(Pane::render_tab_bar_button(
                                    1,
                                    if pane.is_zoomed() {
                                        "icons/minimize_8.svg"
                                    } else {
                                        "icons/maximize_8.svg"
                                    },
                                    pane.is_zoomed(),
                                    Some((
                                        "Toggle Zoom".into(),
                                        Some(Box::new(workspace::ToggleZoom)),
                                    )),
                                    cx,
                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
                                    None,
                                ))
                                .into_any()
                        });
                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
                        pane.toolbar()
                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
                        pane
                    });

                    // Keep `saved_conversations` in sync with the contents of
                    // `CONVERSATIONS_DIR`: re-list it on every file-system event.
                    const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
                    let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
                        let mut events = fs
                            .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
                            .await;
                        while events.next().await.is_some() {
                            let saved_conversations = SavedConversationMetadata::list(fs.clone())
                                .await
                                .log_err()
                                .unwrap_or_default();
                            this.update(&mut cx, |this, _| {
                                this.saved_conversations = saved_conversations
                            })
                            .ok();
                        }

                        anyhow::Ok(())
                    });

                    let mut this = Self {
                        pane,
                        api_key: Rc::new(RefCell::new(None)),
                        api_key_editor: None,
                        has_read_credentials: false,
                        languages: workspace.app_state().languages.clone(),
                        fs: workspace.app_state().fs.clone(),
                        width: None,
                        height: None,
                        subscriptions: Default::default(),
                        saved_conversations,
                        _watch_saved_conversations,
                    };

                    // Emit `DockPositionChanged` only when a settings change actually
                    // moves the dock, comparing against the last observed position.
                    let mut old_dock_position = this.position(cx);
                    this.subscriptions = vec![
                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
                        cx.subscribe(&this.pane, Self::handle_pane_event),
                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
                            let new_dock_position = this.position(cx);
                            if new_dock_position != old_dock_position {
                                old_dock_position = new_dock_position;
                                cx.emit(AssistantPanelEvent::DockPositionChanged);
                            }
                        }),
                    ];

                    this
                })
            })
        })
    }

    /// Re-emits the pane's zoom, focus and remove events as panel events;
    /// all other pane events are ignored.
    fn handle_pane_event(
        &mut self,
        _pane: ViewHandle<Pane>,
        event: &pane::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
            _ => {}
        }
    }

    /// Creates a new `AssistantEditor` sharing this panel's API key, language
    /// registry and file system, subscribes to its events, and adds it to the pane.
    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
        // Whether the panel had focus before the new item is added; passed to
        // `Pane::add_item` (presumably as its focus-item flag — confirm).
        let focus = self.has_focus(cx);
        let editor = cx.add_view(|cx| {
            AssistantEditor::new(
                self.api_key.clone(),
                self.languages.clone(),
                self.fs.clone(),
                cx,
            )
        });
        self.subscriptions
            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
        self.pane.update(cx, |pane, cx| {
            pane.add_item(Box::new(editor), true, focus, None, cx)
        });
    }

    /// Notifies the pane when an editor's tab content changes (presumably so
    /// its tab re-renders).
    fn handle_assistant_editor_event(
        &mut self,
        _: ViewHandle<AssistantEditor>,
        event: &AssistantEditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
        }
    }

    /// Handles `menu::Confirm` while the API-key prompt is shown.
    ///
    /// A non-empty key is written to the platform credential store under
    /// `OPENAI_API_URL` (failures are only logged), stored in the shared key
    /// cell, and the prompt is dismissed. An empty key leaves the prompt
    /// open. When no prompt is shown, the action is propagated.
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        if let Some(api_key) = self
            .api_key_editor
            .as_ref()
            .map(|editor| editor.read(cx).text(cx))
        {
            if !api_key.is_empty() {
                cx.platform()
                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
                *self.api_key.borrow_mut() = Some(api_key);
                self.api_key_editor.take();
                cx.focus_self();
                cx.notify();
            }
        } else {
            cx.propagate_action();
        }
    }

    /// Deletes the stored credential, forgets the in-memory key, and shows a
    /// fresh API-key prompt.
    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
        self.api_key.take();
        self.api_key_editor = Some(build_api_key_editor(cx));
        cx.focus_self();
        cx.notify();
    }
}
306
307fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
308 cx.add_view(|cx| {
309 let mut editor = Editor::single_line(
310 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
311 cx,
312 );
313 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
314 editor
315 })
316}
317
/// The panel emits [`AssistantPanelEvent`]s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
321
322impl View for AssistantPanel {
323 fn ui_name() -> &'static str {
324 "AssistantPanel"
325 }
326
327 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
328 let style = &theme::current(cx).assistant;
329 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
330 Flex::column()
331 .with_child(
332 Text::new(
333 "Paste your OpenAI API key and press Enter to use the assistant",
334 style.api_key_prompt.text.clone(),
335 )
336 .aligned(),
337 )
338 .with_child(
339 ChildView::new(api_key_editor, cx)
340 .contained()
341 .with_style(style.api_key_editor.container)
342 .aligned(),
343 )
344 .contained()
345 .with_style(style.api_key_prompt.container)
346 .aligned()
347 .into_any()
348 } else {
349 ChildView::new(&self.pane, cx).into_any()
350 }
351 }
352
353 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
354 if cx.is_self_focused() {
355 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
356 cx.focus(api_key_editor);
357 } else {
358 cx.focus(&self.pane);
359 }
360 }
361 }
362}
363
364impl Panel for AssistantPanel {
365 fn position(&self, cx: &WindowContext) -> DockPosition {
366 match settings::get::<AssistantSettings>(cx).dock {
367 AssistantDockPosition::Left => DockPosition::Left,
368 AssistantDockPosition::Bottom => DockPosition::Bottom,
369 AssistantDockPosition::Right => DockPosition::Right,
370 }
371 }
372
373 fn position_is_valid(&self, _: DockPosition) -> bool {
374 true
375 }
376
377 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
378 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
379 let dock = match position {
380 DockPosition::Left => AssistantDockPosition::Left,
381 DockPosition::Bottom => AssistantDockPosition::Bottom,
382 DockPosition::Right => AssistantDockPosition::Right,
383 };
384 settings.dock = Some(dock);
385 });
386 }
387
388 fn size(&self, cx: &WindowContext) -> f32 {
389 let settings = settings::get::<AssistantSettings>(cx);
390 match self.position(cx) {
391 DockPosition::Left | DockPosition::Right => {
392 self.width.unwrap_or_else(|| settings.default_width)
393 }
394 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
395 }
396 }
397
398 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
399 match self.position(cx) {
400 DockPosition::Left | DockPosition::Right => self.width = Some(size),
401 DockPosition::Bottom => self.height = Some(size),
402 }
403 cx.notify();
404 }
405
406 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
407 matches!(event, AssistantPanelEvent::ZoomIn)
408 }
409
410 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
411 matches!(event, AssistantPanelEvent::ZoomOut)
412 }
413
414 fn is_zoomed(&self, cx: &WindowContext) -> bool {
415 self.pane.read(cx).is_zoomed()
416 }
417
418 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
419 self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
420 }
421
422 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
423 if active {
424 if self.api_key.borrow().is_none() && !self.has_read_credentials {
425 self.has_read_credentials = true;
426 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
427 Some(api_key)
428 } else if let Some((_, api_key)) = cx
429 .platform()
430 .read_credentials(OPENAI_API_URL)
431 .log_err()
432 .flatten()
433 {
434 String::from_utf8(api_key).log_err()
435 } else {
436 None
437 };
438 if let Some(api_key) = api_key {
439 *self.api_key.borrow_mut() = Some(api_key);
440 } else if self.api_key_editor.is_none() {
441 self.api_key_editor = Some(build_api_key_editor(cx));
442 cx.notify();
443 }
444 }
445
446 if self.pane.read(cx).items_len() == 0 {
447 self.add_context(cx);
448 }
449 }
450 }
451
452 fn icon_path(&self) -> &'static str {
453 "icons/robot_14.svg"
454 }
455
456 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
457 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
458 }
459
460 fn should_change_position_on_event(event: &Self::Event) -> bool {
461 matches!(event, AssistantPanelEvent::DockPositionChanged)
462 }
463
464 fn should_activate_on_event(_: &Self::Event) -> bool {
465 false
466 }
467
468 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
469 matches!(event, AssistantPanelEvent::Close)
470 }
471
472 fn has_focus(&self, cx: &WindowContext) -> bool {
473 self.pane.read(cx).has_focus()
474 || self
475 .api_key_editor
476 .as_ref()
477 .map_or(false, |editor| editor.is_focused(cx))
478 }
479
480 fn is_focus_event(event: &Self::Event) -> bool {
481 matches!(event, AssistantPanelEvent::Focus)
482 }
483}
484
/// Events emitted by an [`Assistant`] model.
enum AssistantEvent {
    // The buffer was edited, or messages were inserted, split, or had their role changed.
    MessagesEdited,
    // The streamed summary text grew, or the summary finished.
    SummaryChanged,
    // A chunk of an assistant reply was inserted into the buffer.
    StreamedCompletion,
}
490
/// Short title for a conversation, streamed from the model by `Assistant::summarize`.
#[derive(Default)]
struct Summary {
    // Text received so far.
    text: String,
    // Set once the summary stream has ended; only then is the summary used when saving.
    done: bool,
}
496
/// Model for one assistant conversation: a single buffer holding every
/// message's text, plus per-message metadata and in-flight requests.
struct Assistant {
    // Buffer containing the text of all messages, one after another.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message, in buffer order; a message ends where the next valid one starts.
    message_anchors: Vec<MessageAnchor>,
    // Role, send time and status of each message, keyed by id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Source of fresh message ids (post-incremented).
    next_message_id: MessageId,
    // Conversation title; `None` until the first summary chunk arrives.
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    // Source of `PendingCompletion` ids (post-incremented).
    completion_count: usize,
    // Completion batches still running; `cancel_last_assist` pops the last one.
    pending_completions: Vec<PendingCompletion>,
    // Model name used for requests and token counting.
    model: String,
    // Most recent token count of the conversation; `None` until the first count completes.
    token_count: Option<usize>,
    // Context size of `model` according to `tiktoken_rs`.
    max_token_count: usize,
    pending_token_count: Task<Option<()>>,
    // API key shared with the panel; `None` until the user provides one.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    // File this conversation was last saved to, if any.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
}
515
/// The assistant model emits [`AssistantEvent`]s.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
519
520impl Assistant {
521 fn new(
522 api_key: Rc<RefCell<Option<String>>>,
523 language_registry: Arc<LanguageRegistry>,
524 cx: &mut ModelContext<Self>,
525 ) -> Self {
526 let model = "gpt-3.5-turbo-0613";
527 let markdown = language_registry.language_for_name("Markdown");
528 let buffer = cx.add_model(|cx| {
529 let mut buffer = Buffer::new(0, "", cx);
530 buffer.set_language_registry(language_registry);
531 cx.spawn_weak(|buffer, mut cx| async move {
532 let markdown = markdown.await?;
533 let buffer = buffer
534 .upgrade(&cx)
535 .ok_or_else(|| anyhow!("buffer was dropped"))?;
536 buffer.update(&mut cx, |buffer, cx| {
537 buffer.set_language(Some(markdown), cx)
538 });
539 anyhow::Ok(())
540 })
541 .detach_and_log_err(cx);
542 buffer
543 });
544
545 let mut this = Self {
546 message_anchors: Default::default(),
547 messages_metadata: Default::default(),
548 next_message_id: Default::default(),
549 summary: None,
550 pending_summary: Task::ready(None),
551 completion_count: Default::default(),
552 pending_completions: Default::default(),
553 token_count: None,
554 max_token_count: tiktoken_rs::model::get_context_size(model),
555 pending_token_count: Task::ready(None),
556 model: model.into(),
557 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
558 pending_save: Task::ready(Ok(())),
559 path: None,
560 api_key,
561 buffer,
562 };
563 let message = MessageAnchor {
564 id: MessageId(post_inc(&mut this.next_message_id.0)),
565 start: language::Anchor::MIN,
566 };
567 this.message_anchors.push(message.clone());
568 this.messages_metadata.insert(
569 message.id,
570 MessageMetadata {
571 role: Role::User,
572 sent_at: Local::now(),
573 status: MessageStatus::Done,
574 },
575 );
576
577 this.count_remaining_tokens(cx);
578 this
579 }
580
581 fn handle_buffer_event(
582 &mut self,
583 _: ModelHandle<Buffer>,
584 event: &language::Event,
585 cx: &mut ModelContext<Self>,
586 ) {
587 match event {
588 language::Event::Edited => {
589 self.count_remaining_tokens(cx);
590 cx.emit(AssistantEvent::MessagesEdited);
591 }
592 _ => {}
593 }
594 }
595
596 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
597 let messages = self
598 .messages(cx)
599 .into_iter()
600 .filter_map(|message| {
601 Some(tiktoken_rs::ChatCompletionRequestMessage {
602 role: match message.role {
603 Role::User => "user".into(),
604 Role::Assistant => "assistant".into(),
605 Role::System => "system".into(),
606 },
607 content: self.buffer.read(cx).text_for_range(message.range).collect(),
608 name: None,
609 })
610 })
611 .collect::<Vec<_>>();
612 let model = self.model.clone();
613 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
614 async move {
615 cx.background().timer(Duration::from_millis(200)).await;
616 let token_count = cx
617 .background()
618 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
619 .await?;
620
621 this.upgrade(&cx)
622 .ok_or_else(|| anyhow!("assistant was dropped"))?
623 .update(&mut cx, |this, cx| {
624 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
625 this.token_count = Some(token_count);
626 cx.notify()
627 });
628 anyhow::Ok(())
629 }
630 .log_err()
631 });
632 }
633
634 fn remaining_tokens(&self) -> Option<isize> {
635 Some(self.max_token_count as isize - self.token_count? as isize)
636 }
637
638 fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
639 self.model = model;
640 self.count_remaining_tokens(cx);
641 cx.notify();
642 }
643
644 fn assist(
645 &mut self,
646 selected_messages: HashSet<MessageId>,
647 cx: &mut ModelContext<Self>,
648 ) -> Vec<MessageAnchor> {
649 let mut user_messages = Vec::new();
650 let mut tasks = Vec::new();
651 for selected_message_id in selected_messages {
652 let selected_message_role =
653 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
654 metadata.role
655 } else {
656 continue;
657 };
658
659 if selected_message_role == Role::Assistant {
660 if let Some(user_message) = self.insert_message_after(
661 selected_message_id,
662 Role::User,
663 MessageStatus::Done,
664 cx,
665 ) {
666 user_messages.push(user_message);
667 } else {
668 continue;
669 }
670 } else {
671 let request = OpenAIRequest {
672 model: self.model.clone(),
673 messages: self
674 .messages(cx)
675 .filter(|message| matches!(message.status, MessageStatus::Done))
676 .flat_map(|message| {
677 let mut system_message = None;
678 if message.id == selected_message_id {
679 system_message = Some(RequestMessage {
680 role: Role::System,
681 content: concat!(
682 "Treat the following messages as additional knowledge you have learned about, ",
683 "but act as if they were not part of this conversation. That is, treat them ",
684 "as if the user didn't see them and couldn't possibly inquire about them."
685 ).into()
686 });
687 }
688
689 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
690 })
691 .chain(Some(RequestMessage {
692 role: Role::System,
693 content: format!(
694 "Direct your reply to message with id {}. Do not include a [Message X] header.",
695 selected_message_id.0
696 ),
697 }))
698 .collect(),
699 stream: true,
700 };
701
702 let Some(api_key) = self.api_key.borrow().clone() else { continue };
703 let stream = stream_completion(api_key, cx.background().clone(), request);
704 let assistant_message = self
705 .insert_message_after(
706 selected_message_id,
707 Role::Assistant,
708 MessageStatus::Pending,
709 cx,
710 )
711 .unwrap();
712
713 tasks.push(cx.spawn_weak({
714 |this, mut cx| async move {
715 let assistant_message_id = assistant_message.id;
716 let stream_completion = async {
717 let mut messages = stream.await?;
718
719 while let Some(message) = messages.next().await {
720 let mut message = message?;
721 if let Some(choice) = message.choices.pop() {
722 this.upgrade(&cx)
723 .ok_or_else(|| anyhow!("assistant was dropped"))?
724 .update(&mut cx, |this, cx| {
725 let text: Arc<str> = choice.delta.content?.into();
726 let message_ix = this.message_anchors.iter().position(
727 |message| message.id == assistant_message_id,
728 )?;
729 this.buffer.update(cx, |buffer, cx| {
730 let offset = this.message_anchors[message_ix + 1..]
731 .iter()
732 .find(|message| message.start.is_valid(buffer))
733 .map_or(buffer.len(), |message| {
734 message
735 .start
736 .to_offset(buffer)
737 .saturating_sub(1)
738 });
739 buffer.edit([(offset..offset, text)], None, cx);
740 });
741 cx.emit(AssistantEvent::StreamedCompletion);
742
743 Some(())
744 });
745 }
746 smol::future::yield_now().await;
747 }
748
749 this.upgrade(&cx)
750 .ok_or_else(|| anyhow!("assistant was dropped"))?
751 .update(&mut cx, |this, cx| {
752 this.pending_completions.retain(|completion| {
753 completion.id != this.completion_count
754 });
755 this.summarize(cx);
756 });
757
758 anyhow::Ok(())
759 };
760
761 let result = stream_completion.await;
762 if let Some(this) = this.upgrade(&cx) {
763 this.update(&mut cx, |this, cx| {
764 if let Some(metadata) =
765 this.messages_metadata.get_mut(&assistant_message.id)
766 {
767 match result {
768 Ok(_) => {
769 metadata.status = MessageStatus::Done;
770 }
771 Err(error) => {
772 metadata.status = MessageStatus::Error(
773 error.to_string().trim().into(),
774 );
775 }
776 }
777 cx.notify();
778 }
779 });
780 }
781 }
782 }));
783 }
784 }
785
786 if !tasks.is_empty() {
787 self.pending_completions.push(PendingCompletion {
788 id: post_inc(&mut self.completion_count),
789 _tasks: tasks,
790 });
791 }
792
793 user_messages
794 }
795
796 fn cancel_last_assist(&mut self) -> bool {
797 self.pending_completions.pop().is_some()
798 }
799
800 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
801 for id in ids {
802 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
803 metadata.role.cycle();
804 cx.emit(AssistantEvent::MessagesEdited);
805 cx.notify();
806 }
807 }
808 }
809
810 fn insert_message_after(
811 &mut self,
812 message_id: MessageId,
813 role: Role,
814 status: MessageStatus,
815 cx: &mut ModelContext<Self>,
816 ) -> Option<MessageAnchor> {
817 if let Some(prev_message_ix) = self
818 .message_anchors
819 .iter()
820 .position(|message| message.id == message_id)
821 {
822 let start = self.buffer.update(cx, |buffer, cx| {
823 let offset = self.message_anchors[prev_message_ix + 1..]
824 .iter()
825 .find(|message| message.start.is_valid(buffer))
826 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
827 buffer.edit([(offset..offset, "\n")], None, cx);
828 buffer.anchor_before(offset + 1)
829 });
830 let message = MessageAnchor {
831 id: MessageId(post_inc(&mut self.next_message_id.0)),
832 start,
833 };
834 self.message_anchors
835 .insert(prev_message_ix + 1, message.clone());
836 self.messages_metadata.insert(
837 message.id,
838 MessageMetadata {
839 role,
840 sent_at: Local::now(),
841 status,
842 },
843 );
844 cx.emit(AssistantEvent::MessagesEdited);
845 Some(message)
846 } else {
847 None
848 }
849 }
850
851 fn split_message(
852 &mut self,
853 range: Range<usize>,
854 cx: &mut ModelContext<Self>,
855 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
856 let start_message = self.message_for_offset(range.start, cx);
857 let end_message = self.message_for_offset(range.end, cx);
858 if let Some((start_message, end_message)) = start_message.zip(end_message) {
859 // Prevent splitting when range spans multiple messages.
860 if start_message.index != end_message.index {
861 return (None, None);
862 }
863
864 let message = start_message;
865 let role = message.role;
866 let mut edited_buffer = false;
867
868 let mut suffix_start = None;
869 if range.start > message.range.start && range.end < message.range.end - 1 {
870 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
871 suffix_start = Some(range.end + 1);
872 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
873 suffix_start = Some(range.end);
874 }
875 }
876
877 let suffix = if let Some(suffix_start) = suffix_start {
878 MessageAnchor {
879 id: MessageId(post_inc(&mut self.next_message_id.0)),
880 start: self.buffer.read(cx).anchor_before(suffix_start),
881 }
882 } else {
883 self.buffer.update(cx, |buffer, cx| {
884 buffer.edit([(range.end..range.end, "\n")], None, cx);
885 });
886 edited_buffer = true;
887 MessageAnchor {
888 id: MessageId(post_inc(&mut self.next_message_id.0)),
889 start: self.buffer.read(cx).anchor_before(range.end + 1),
890 }
891 };
892
893 self.message_anchors
894 .insert(message.index + 1, suffix.clone());
895 self.messages_metadata.insert(
896 suffix.id,
897 MessageMetadata {
898 role,
899 sent_at: Local::now(),
900 status: MessageStatus::Done,
901 },
902 );
903
904 let new_messages = if range.start == range.end || range.start == message.range.start {
905 (None, Some(suffix))
906 } else {
907 let mut prefix_end = None;
908 if range.start > message.range.start && range.end < message.range.end - 1 {
909 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
910 prefix_end = Some(range.start + 1);
911 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
912 == Some('\n')
913 {
914 prefix_end = Some(range.start);
915 }
916 }
917
918 let selection = if let Some(prefix_end) = prefix_end {
919 cx.emit(AssistantEvent::MessagesEdited);
920 MessageAnchor {
921 id: MessageId(post_inc(&mut self.next_message_id.0)),
922 start: self.buffer.read(cx).anchor_before(prefix_end),
923 }
924 } else {
925 self.buffer.update(cx, |buffer, cx| {
926 buffer.edit([(range.start..range.start, "\n")], None, cx)
927 });
928 edited_buffer = true;
929 MessageAnchor {
930 id: MessageId(post_inc(&mut self.next_message_id.0)),
931 start: self.buffer.read(cx).anchor_before(range.end + 1),
932 }
933 };
934
935 self.message_anchors
936 .insert(message.index + 1, selection.clone());
937 self.messages_metadata.insert(
938 selection.id,
939 MessageMetadata {
940 role,
941 sent_at: Local::now(),
942 status: MessageStatus::Done,
943 },
944 );
945 (Some(selection), Some(suffix))
946 };
947
948 if !edited_buffer {
949 cx.emit(AssistantEvent::MessagesEdited);
950 }
951 new_messages
952 } else {
953 (None, None)
954 }
955 }
956
957 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
958 if self.message_anchors.len() >= 2 && self.summary.is_none() {
959 let api_key = self.api_key.borrow().clone();
960 if let Some(api_key) = api_key {
961 let messages = self
962 .messages(cx)
963 .take(2)
964 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
965 .chain(Some(RequestMessage {
966 role: Role::User,
967 content:
968 "Summarize the conversation into a short title without punctuation"
969 .into(),
970 }));
971 let request = OpenAIRequest {
972 model: self.model.clone(),
973 messages: messages.collect(),
974 stream: true,
975 };
976
977 let stream = stream_completion(api_key, cx.background().clone(), request);
978 self.pending_summary = cx.spawn(|this, mut cx| {
979 async move {
980 let mut messages = stream.await?;
981
982 while let Some(message) = messages.next().await {
983 let mut message = message?;
984 if let Some(choice) = message.choices.pop() {
985 let text = choice.delta.content.unwrap_or_default();
986 this.update(&mut cx, |this, cx| {
987 this.summary
988 .get_or_insert(Default::default())
989 .text
990 .push_str(&text);
991 cx.emit(AssistantEvent::SummaryChanged);
992 });
993 }
994 }
995
996 this.update(&mut cx, |this, cx| {
997 if let Some(summary) = this.summary.as_mut() {
998 summary.done = true;
999 cx.emit(AssistantEvent::SummaryChanged);
1000 }
1001 });
1002
1003 anyhow::Ok(())
1004 }
1005 .log_err()
1006 });
1007 }
1008 }
1009 }
1010
1011 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1012 self.messages_for_offsets([offset], cx).pop()
1013 }
1014
1015 fn messages_for_offsets(
1016 &self,
1017 offsets: impl IntoIterator<Item = usize>,
1018 cx: &AppContext,
1019 ) -> Vec<Message> {
1020 let mut result = Vec::new();
1021
1022 let buffer_len = self.buffer.read(cx).len();
1023 let mut messages = self.messages(cx).peekable();
1024 let mut offsets = offsets.into_iter().peekable();
1025 while let Some(offset) = offsets.next() {
1026 // Skip messages that start after the offset.
1027 while messages.peek().map_or(false, |message| {
1028 message.range.end < offset || (message.range.end == offset && offset < buffer_len)
1029 }) {
1030 messages.next();
1031 }
1032 let Some(message) = messages.peek() else { continue };
1033
1034 // Skip offsets that are in the same message.
1035 while offsets.peek().map_or(false, |offset| {
1036 message.range.contains(offset) || message.range.end == buffer_len
1037 }) {
1038 offsets.next();
1039 }
1040
1041 result.push(message.clone());
1042 }
1043 result
1044 }
1045
1046 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1047 let buffer = self.buffer.read(cx);
1048 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1049 iter::from_fn(move || {
1050 while let Some((ix, message_anchor)) = message_anchors.next() {
1051 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1052 let message_start = message_anchor.start.to_offset(buffer);
1053 let mut message_end = None;
1054 while let Some((_, next_message)) = message_anchors.peek() {
1055 if next_message.start.is_valid(buffer) {
1056 message_end = Some(next_message.start);
1057 break;
1058 } else {
1059 message_anchors.next();
1060 }
1061 }
1062 let message_end = message_end
1063 .unwrap_or(language::Anchor::MAX)
1064 .to_offset(buffer);
1065 return Some(Message {
1066 index: ix,
1067 range: message_start..message_end,
1068 id: message_anchor.id,
1069 anchor: message_anchor.start,
1070 role: metadata.role,
1071 sent_at: metadata.sent_at,
1072 status: metadata.status.clone(),
1073 });
1074 }
1075 None
1076 })
1077 }
1078
1079 fn save(
1080 &mut self,
1081 debounce: Option<Duration>,
1082 fs: Arc<dyn Fs>,
1083 cx: &mut ModelContext<Assistant>,
1084 ) {
1085 self.pending_save = cx.spawn(|this, mut cx| async move {
1086 if let Some(debounce) = debounce {
1087 cx.background().timer(debounce).await;
1088 }
1089
1090 let (old_path, summary) = this.read_with(&cx, |this, _| {
1091 let path = this.path.clone();
1092 let summary = if let Some(summary) = this.summary.as_ref() {
1093 if summary.done {
1094 Some(summary.text.clone())
1095 } else {
1096 None
1097 }
1098 } else {
1099 None
1100 };
1101 (path, summary)
1102 });
1103
1104 if let Some(summary) = summary {
1105 let conversation = SavedConversation {
1106 zed: "conversation".into(),
1107 version: "0.1".into(),
1108 messages: this.read_with(&cx, |this, cx| {
1109 this.messages(cx)
1110 .map(|message| message.to_open_ai_message(this.buffer.read(cx)))
1111 .collect()
1112 }),
1113 };
1114
1115 let path = if let Some(old_path) = old_path {
1116 old_path
1117 } else {
1118 let mut discriminant = 1;
1119 let mut new_path;
1120 loop {
1121 new_path = CONVERSATIONS_DIR.join(&format!(
1122 "{} - {}.zed.json",
1123 summary.trim(),
1124 discriminant
1125 ));
1126 if fs.is_file(&new_path).await {
1127 discriminant += 1;
1128 } else {
1129 break;
1130 }
1131 }
1132 new_path
1133 };
1134
1135 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1136 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1137 .await?;
1138 this.update(&mut cx, |this, _| this.path = Some(path));
1139 }
1140
1141 Ok(())
1142 });
1143 }
1144}
1145
/// Bookkeeping for a completion request that has been started.
///
/// NOTE(review): `id` presumably identifies the completion so it can be found or
/// cancelled later, and `_tasks` appears to be held only to keep the spawned tasks
/// alive (it is never read) — confirm against the `Assistant` code that creates these.
struct PendingCompletion {
    id: usize,
    _tasks: Vec<Task<()>>,
}
1150
/// Events emitted by [`AssistantEditor`].
enum AssistantEditorEvent {
    /// The tab title may have changed; emitted when the conversation summary changes.
    TabContentChanged,
}
1154
/// The newest cursor's position relative to the viewport.
///
/// Recorded so the view can keep the cursor at the same on-screen offset while a
/// completion streams into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    // x: horizontal scroll position; y: display rows from the scroll top to the cursor row.
    offset_before_cursor: Vector2F,
    // Anchor of the newest selection's head.
    cursor: Anchor,
}
1160
/// A view that edits one assistant conversation as a text buffer.
struct AssistantEditor {
    // The conversation model; its buffer is what `editor` displays.
    assistant: ModelHandle<Assistant>,
    // Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    // The inner text editor over the conversation buffer.
    editor: ViewHandle<Editor>,
    // IDs of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    // Cursor offset to preserve while completions stream in; `None` when not pinned.
    scroll_position: Option<ScrollPosition>,
    // Presumably held so the observations/subscriptions stay active for the view's lifetime.
    _subscriptions: Vec<Subscription>,
}
1169
1170impl AssistantEditor {
1171 fn new(
1172 api_key: Rc<RefCell<Option<String>>>,
1173 language_registry: Arc<LanguageRegistry>,
1174 fs: Arc<dyn Fs>,
1175 cx: &mut ViewContext<Self>,
1176 ) -> Self {
1177 let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
1178 let editor = cx.add_view(|cx| {
1179 let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
1180 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1181 editor.set_show_gutter(false, cx);
1182 editor
1183 });
1184
1185 let _subscriptions = vec![
1186 cx.observe(&assistant, |_, _, cx| cx.notify()),
1187 cx.subscribe(&assistant, Self::handle_assistant_event),
1188 cx.subscribe(&editor, Self::handle_editor_event),
1189 ];
1190
1191 let mut this = Self {
1192 assistant,
1193 editor,
1194 blocks: Default::default(),
1195 scroll_position: None,
1196 fs,
1197 _subscriptions,
1198 };
1199 this.update_message_headers(cx);
1200 this
1201 }
1202
1203 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1204 let cursors = self.cursors(cx);
1205
1206 let user_messages = self.assistant.update(cx, |assistant, cx| {
1207 let selected_messages = assistant
1208 .messages_for_offsets(cursors, cx)
1209 .into_iter()
1210 .map(|message| message.id)
1211 .collect();
1212 assistant.assist(selected_messages, cx)
1213 });
1214 let new_selections = user_messages
1215 .iter()
1216 .map(|message| {
1217 let cursor = message
1218 .start
1219 .to_offset(self.assistant.read(cx).buffer.read(cx));
1220 cursor..cursor
1221 })
1222 .collect::<Vec<_>>();
1223 if !new_selections.is_empty() {
1224 self.editor.update(cx, |editor, cx| {
1225 editor.change_selections(
1226 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1227 cx,
1228 |selections| selections.select_ranges(new_selections),
1229 );
1230 });
1231 }
1232 }
1233
1234 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1235 if !self
1236 .assistant
1237 .update(cx, |assistant, _| assistant.cancel_last_assist())
1238 {
1239 cx.propagate_action();
1240 }
1241 }
1242
1243 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1244 let cursors = self.cursors(cx);
1245 self.assistant.update(cx, |assistant, cx| {
1246 let messages = assistant
1247 .messages_for_offsets(cursors, cx)
1248 .into_iter()
1249 .map(|message| message.id)
1250 .collect();
1251 assistant.cycle_message_roles(messages, cx)
1252 });
1253 }
1254
1255 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1256 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1257 selections
1258 .into_iter()
1259 .map(|selection| selection.head())
1260 .collect()
1261 }
1262
1263 fn handle_assistant_event(
1264 &mut self,
1265 _: ModelHandle<Assistant>,
1266 event: &AssistantEvent,
1267 cx: &mut ViewContext<Self>,
1268 ) {
1269 match event {
1270 AssistantEvent::MessagesEdited => {
1271 self.update_message_headers(cx);
1272 self.assistant.update(cx, |assistant, cx| {
1273 assistant.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1274 });
1275 }
1276 AssistantEvent::SummaryChanged => {
1277 cx.emit(AssistantEditorEvent::TabContentChanged);
1278 self.assistant.update(cx, |assistant, cx| {
1279 assistant.save(None, self.fs.clone(), cx);
1280 });
1281 }
1282 AssistantEvent::StreamedCompletion => {
1283 self.editor.update(cx, |editor, cx| {
1284 if let Some(scroll_position) = self.scroll_position {
1285 let snapshot = editor.snapshot(cx);
1286 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1287 let scroll_top =
1288 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1289 editor.set_scroll_position(
1290 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1291 cx,
1292 );
1293 }
1294 });
1295 }
1296 }
1297 }
1298
1299 fn handle_editor_event(
1300 &mut self,
1301 _: ViewHandle<Editor>,
1302 event: &editor::Event,
1303 cx: &mut ViewContext<Self>,
1304 ) {
1305 match event {
1306 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1307 let cursor_scroll_position = self.cursor_scroll_position(cx);
1308 if *autoscroll {
1309 self.scroll_position = cursor_scroll_position;
1310 } else if self.scroll_position != cursor_scroll_position {
1311 self.scroll_position = None;
1312 }
1313 }
1314 editor::Event::SelectionsChanged { .. } => {
1315 self.scroll_position = self.cursor_scroll_position(cx);
1316 }
1317 _ => {}
1318 }
1319 }
1320
1321 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1322 self.editor.update(cx, |editor, cx| {
1323 let snapshot = editor.snapshot(cx);
1324 let cursor = editor.selections.newest_anchor().head();
1325 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1326 let scroll_position = editor
1327 .scroll_manager
1328 .anchor()
1329 .scroll_position(&snapshot.display_snapshot);
1330
1331 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1332 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1333 Some(ScrollPosition {
1334 cursor,
1335 offset_before_cursor: vec2f(
1336 scroll_position.x(),
1337 cursor_row - scroll_position.y(),
1338 ),
1339 })
1340 } else {
1341 None
1342 }
1343 })
1344 }
1345
1346 fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1347 self.editor.update(cx, |editor, cx| {
1348 let buffer = editor.buffer().read(cx).snapshot(cx);
1349 let excerpt_id = *buffer.as_singleton().unwrap().0;
1350 let old_blocks = std::mem::take(&mut self.blocks);
1351 let new_blocks = self
1352 .assistant
1353 .read(cx)
1354 .messages(cx)
1355 .map(|message| BlockProperties {
1356 position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1357 height: 2,
1358 style: BlockStyle::Sticky,
1359 render: Arc::new({
1360 let assistant = self.assistant.clone();
1361 // let metadata = message.metadata.clone();
1362 // let message = message.clone();
1363 move |cx| {
1364 enum Sender {}
1365 enum ErrorTooltip {}
1366
1367 let theme = theme::current(cx);
1368 let style = &theme.assistant;
1369 let message_id = message.id;
1370 let sender = MouseEventHandler::<Sender, _>::new(
1371 message_id.0,
1372 cx,
1373 |state, _| match message.role {
1374 Role::User => {
1375 let style = style.user_sender.style_for(state, false);
1376 Label::new("You", style.text.clone())
1377 .contained()
1378 .with_style(style.container)
1379 }
1380 Role::Assistant => {
1381 let style = style.assistant_sender.style_for(state, false);
1382 Label::new("Assistant", style.text.clone())
1383 .contained()
1384 .with_style(style.container)
1385 }
1386 Role::System => {
1387 let style = style.system_sender.style_for(state, false);
1388 Label::new("System", style.text.clone())
1389 .contained()
1390 .with_style(style.container)
1391 }
1392 },
1393 )
1394 .with_cursor_style(CursorStyle::PointingHand)
1395 .on_down(MouseButton::Left, {
1396 let assistant = assistant.clone();
1397 move |_, _, cx| {
1398 assistant.update(cx, |assistant, cx| {
1399 assistant.cycle_message_roles(
1400 HashSet::from_iter(Some(message_id)),
1401 cx,
1402 )
1403 })
1404 }
1405 });
1406
1407 Flex::row()
1408 .with_child(sender.aligned())
1409 .with_child(
1410 Label::new(
1411 message.sent_at.format("%I:%M%P").to_string(),
1412 style.sent_at.text.clone(),
1413 )
1414 .contained()
1415 .with_style(style.sent_at.container)
1416 .aligned(),
1417 )
1418 .with_children(
1419 if let MessageStatus::Error(error) = &message.status {
1420 Some(
1421 Svg::new("icons/circle_x_mark_12.svg")
1422 .with_color(style.error_icon.color)
1423 .constrained()
1424 .with_width(style.error_icon.width)
1425 .contained()
1426 .with_style(style.error_icon.container)
1427 .with_tooltip::<ErrorTooltip>(
1428 message_id.0,
1429 error.to_string(),
1430 None,
1431 theme.tooltip.clone(),
1432 cx,
1433 )
1434 .aligned(),
1435 )
1436 } else {
1437 None
1438 },
1439 )
1440 .aligned()
1441 .left()
1442 .contained()
1443 .with_style(style.header)
1444 .into_any()
1445 }
1446 }),
1447 disposition: BlockDisposition::Above,
1448 })
1449 .collect::<Vec<_>>();
1450
1451 editor.remove_blocks(old_blocks, None, cx);
1452 let ids = editor.insert_blocks(new_blocks, None, cx);
1453 self.blocks = HashSet::from_iter(ids);
1454 });
1455 }
1456
1457 fn quote_selection(
1458 workspace: &mut Workspace,
1459 _: &QuoteSelection,
1460 cx: &mut ViewContext<Workspace>,
1461 ) {
1462 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1463 return;
1464 };
1465 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1466 return;
1467 };
1468
1469 let text = editor.read_with(cx, |editor, cx| {
1470 let range = editor.selections.newest::<usize>(cx).range();
1471 let buffer = editor.buffer().read(cx).snapshot(cx);
1472 let start_language = buffer.language_at(range.start);
1473 let end_language = buffer.language_at(range.end);
1474 let language_name = if start_language == end_language {
1475 start_language.map(|language| language.name())
1476 } else {
1477 None
1478 };
1479 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1480
1481 let selected_text = buffer.text_for_range(range).collect::<String>();
1482 if selected_text.is_empty() {
1483 None
1484 } else {
1485 Some(if language_name == "markdown" {
1486 selected_text
1487 .lines()
1488 .map(|line| format!("> {}", line))
1489 .collect::<Vec<_>>()
1490 .join("\n")
1491 } else {
1492 format!("```{language_name}\n{selected_text}\n```")
1493 })
1494 }
1495 });
1496
1497 // Activate the panel
1498 if !panel.read(cx).has_focus(cx) {
1499 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1500 }
1501
1502 if let Some(text) = text {
1503 panel.update(cx, |panel, cx| {
1504 if let Some(assistant) = panel
1505 .pane
1506 .read(cx)
1507 .active_item()
1508 .and_then(|item| item.downcast::<AssistantEditor>())
1509 .ok_or_else(|| anyhow!("no active context"))
1510 .log_err()
1511 {
1512 assistant.update(cx, |assistant, cx| {
1513 assistant
1514 .editor
1515 .update(cx, |editor, cx| editor.insert(&text, cx))
1516 });
1517 }
1518 });
1519 }
1520 }
1521
1522 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1523 let editor = self.editor.read(cx);
1524 let assistant = self.assistant.read(cx);
1525 if editor.selections.count() == 1 {
1526 let selection = editor.selections.newest::<usize>(cx);
1527 let mut copied_text = String::new();
1528 let mut spanned_messages = 0;
1529 for message in assistant.messages(cx) {
1530 if message.range.start >= selection.range().end {
1531 break;
1532 } else if message.range.end >= selection.range().start {
1533 let range = cmp::max(message.range.start, selection.range().start)
1534 ..cmp::min(message.range.end, selection.range().end);
1535 if !range.is_empty() {
1536 spanned_messages += 1;
1537 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1538 for chunk in assistant.buffer.read(cx).text_for_range(range) {
1539 copied_text.push_str(&chunk);
1540 }
1541 copied_text.push('\n');
1542 }
1543 }
1544 }
1545
1546 if spanned_messages > 1 {
1547 cx.platform()
1548 .write_to_clipboard(ClipboardItem::new(copied_text));
1549 return;
1550 }
1551 }
1552
1553 cx.propagate_action();
1554 }
1555
1556 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1557 self.assistant.update(cx, |assistant, cx| {
1558 let selections = self.editor.read(cx).selections.disjoint_anchors();
1559 for selection in selections.into_iter() {
1560 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
1561 let range = selection
1562 .map(|endpoint| endpoint.to_offset(&buffer))
1563 .range();
1564 assistant.split_message(range, cx);
1565 }
1566 });
1567 }
1568
1569 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
1570 self.assistant.update(cx, |assistant, cx| {
1571 assistant.save(None, self.fs.clone(), cx)
1572 });
1573 }
1574
1575 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1576 self.assistant.update(cx, |assistant, cx| {
1577 let new_model = match assistant.model.as_str() {
1578 "gpt-4-0613" => "gpt-3.5-turbo-0613",
1579 _ => "gpt-4-0613",
1580 };
1581 assistant.set_model(new_model.into(), cx);
1582 });
1583 }
1584
1585 fn title(&self, cx: &AppContext) -> String {
1586 self.assistant
1587 .read(cx)
1588 .summary
1589 .as_ref()
1590 .map(|summary| summary.text.clone())
1591 .unwrap_or_else(|| "New Context".into())
1592 }
1593}
1594
impl Entity for AssistantEditor {
    // AssistantEditor emits `AssistantEditorEvent`s (currently only tab-title changes).
    type Event = AssistantEditorEvent;
}
1598
1599impl View for AssistantEditor {
1600 fn ui_name() -> &'static str {
1601 "AssistantEditor"
1602 }
1603
1604 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1605 enum Model {}
1606 let theme = &theme::current(cx).assistant;
1607 let assistant = &self.assistant.read(cx);
1608 let model = assistant.model.clone();
1609 let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1610 let remaining_tokens_style = if remaining_tokens <= 0 {
1611 &theme.no_remaining_tokens
1612 } else {
1613 &theme.remaining_tokens
1614 };
1615 Label::new(
1616 remaining_tokens.to_string(),
1617 remaining_tokens_style.text.clone(),
1618 )
1619 .contained()
1620 .with_style(remaining_tokens_style.container)
1621 });
1622
1623 Stack::new()
1624 .with_child(
1625 ChildView::new(&self.editor, cx)
1626 .contained()
1627 .with_style(theme.container),
1628 )
1629 .with_child(
1630 Flex::row()
1631 .with_child(
1632 MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1633 let style = theme.model.style_for(state, false);
1634 Label::new(model, style.text.clone())
1635 .contained()
1636 .with_style(style.container)
1637 })
1638 .with_cursor_style(CursorStyle::PointingHand)
1639 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1640 )
1641 .with_children(remaining_tokens)
1642 .contained()
1643 .with_style(theme.model_info_container)
1644 .aligned()
1645 .top()
1646 .right(),
1647 )
1648 .into_any()
1649 }
1650
1651 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1652 if cx.is_self_focused() {
1653 cx.focus(&self.editor);
1654 }
1655 }
1656}
1657
1658impl Item for AssistantEditor {
1659 fn tab_content<V: View>(
1660 &self,
1661 _: Option<usize>,
1662 style: &theme::Tab,
1663 cx: &gpui::AppContext,
1664 ) -> AnyElement<V> {
1665 let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1666 Label::new(title, style.label.clone()).into_any()
1667 }
1668
1669 fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1670 Some(self.title(cx).into())
1671 }
1672
1673 fn as_searchable(
1674 &self,
1675 _: &ViewHandle<Self>,
1676 ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1677 Some(Box::new(self.editor.clone()))
1678 }
1679}
1680
/// Identifier of a message within a conversation.
///
/// Used as the key for per-message metadata and as the mouse/tooltip region id of
/// the message's header.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1683
/// Marks where a message begins in the conversation buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    // Buffer anchor at the message's first character. A message runs until the next
    // message's (still valid) start anchor, or to the end of the buffer.
    start: language::Anchor,
}
1689
/// Per-message attributes that are not stored in the buffer text.
#[derive(Clone, Debug)]
struct MessageMetadata {
    // Who authored the message (shown in the header; cycled by clicking it).
    role: Role,
    // Local timestamp shown in the message header.
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
1696
/// Lifecycle state of a message.
#[derive(Clone, Debug)]
enum MessageStatus {
    // NOTE(review): presumably the message is still awaiting/receiving a completion — confirm.
    Pending,
    Done,
    // The error text is displayed as a tooltip on the header's error icon.
    Error(Arc<str>),
}
1703
/// A resolved view of one message: its buffer range plus its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    // Buffer offsets covered by the message, from its start up to the next message's start.
    range: Range<usize>,
    // NOTE(review): presumably the message's position in the conversation — confirm.
    index: usize,
    id: MessageId,
    // Buffer anchor at the start of the message.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
1714
1715impl Message {
1716 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
1717 let mut content = format!("[Message {}]\n", self.id.0).to_string();
1718 content.extend(buffer.text_for_range(self.range.clone()));
1719 RequestMessage {
1720 role: self.role,
1721 content,
1722 }
1723 }
1724}
1725
1726async fn stream_completion(
1727 api_key: String,
1728 executor: Arc<Background>,
1729 mut request: OpenAIRequest,
1730) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1731 request.stream = true;
1732
1733 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1734
1735 let json_data = serde_json::to_string(&request)?;
1736 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1737 .header("Content-Type", "application/json")
1738 .header("Authorization", format!("Bearer {}", api_key))
1739 .body(json_data)?
1740 .send_async()
1741 .await?;
1742
1743 let status = response.status();
1744 if status == StatusCode::OK {
1745 executor
1746 .spawn(async move {
1747 let mut lines = BufReader::new(response.body_mut()).lines();
1748
1749 fn parse_line(
1750 line: Result<String, io::Error>,
1751 ) -> Result<Option<OpenAIResponseStreamEvent>> {
1752 if let Some(data) = line?.strip_prefix("data: ") {
1753 let event = serde_json::from_str(&data)?;
1754 Ok(Some(event))
1755 } else {
1756 Ok(None)
1757 }
1758 }
1759
1760 while let Some(line) = lines.next().await {
1761 if let Some(event) = parse_line(line).transpose() {
1762 let done = event.as_ref().map_or(false, |event| {
1763 event
1764 .choices
1765 .last()
1766 .map_or(false, |choice| choice.finish_reason.is_some())
1767 });
1768 if tx.unbounded_send(event).is_err() {
1769 break;
1770 }
1771
1772 if done {
1773 break;
1774 }
1775 }
1776 }
1777
1778 anyhow::Ok(())
1779 })
1780 .detach();
1781
1782 Ok(rx)
1783 } else {
1784 let mut body = String::new();
1785 response.body_mut().read_to_string(&mut body).await?;
1786
1787 #[derive(Deserialize)]
1788 struct OpenAIResponse {
1789 error: OpenAIError,
1790 }
1791
1792 #[derive(Deserialize)]
1793 struct OpenAIError {
1794 message: String,
1795 }
1796
1797 match serde_json::from_str::<OpenAIResponse>(&body) {
1798 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1799 "Failed to connect to OpenAI API: {}",
1800 response.error.message,
1801 )),
1802
1803 _ => Err(anyhow!(
1804 "Failed to connect to OpenAI API: {} {}",
1805 response.status(),
1806 body,
1807 )),
1808 }
1809 }
1810}
1811
/// Tests for how `Assistant` tracks messages: insertion, merging on deletion,
/// splitting, and mapping buffer offsets back to messages.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    /// Message ranges follow buffer edits, including merges caused by deleting across
    /// message boundaries and the undo/redo of such deletions.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A new assistant starts with a single, empty user message.
        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 grows message_1 to 0..1 (presumably a separating
        // newline) and starts the new message at offset 1.
        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Typing at the start of each message grows that message's range.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between message_2
        // and message_3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// `split_message` with an empty range splits at that offset, reusing an existing
    /// newline where possible. A non-empty range produces a message before, within,
    /// and after it (returned as the pair of new message ids).
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting on a non-empty range yields two new messages: one for the range
        // and one for the text after it.
        let (message_6, message_7) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// `messages_for_offsets` maps each offset to its containing message. Several
    /// offsets inside the same message yield that message only once.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which is reported once.
        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        /// Ids of the messages returned by `messages_for_offsets` for `offsets`.
        fn message_ids_for_offsets(
            assistant: &ModelHandle<Assistant>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            assistant
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Snapshot of `(id, role, range)` for every message, for compact assertions.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.range))
            .collect()
    }
}