1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
4};
5use anyhow::{anyhow, Result};
6use chrono::{DateTime, Local};
7use collections::{HashMap, HashSet};
8use editor::{
9 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
10 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
11 Anchor, Editor, ToOffset,
12};
13use fs::Fs;
14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
15use gpui::{
16 actions,
17 elements::*,
18 executor::Background,
19 geometry::vector::{vec2f, Vector2F},
20 platform::{CursorStyle, MouseButton},
21 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
22 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
23};
24use isahc::{http::StatusCode, Request, RequestExt};
25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
26use serde::Deserialize;
27use settings::SettingsStore;
28use std::{
29 borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
30 time::Duration,
31};
32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
33use workspace::{
34 dock::{DockPosition, Panel},
35 item::Item,
36 pane, Pane, Workspace,
37};
38
/// Base URL of the OpenAI REST API.
///
/// Also used as the key under which the user's API key is written to, read
/// from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
40
// Actions in the `assistant` namespace. On the stable release channel,
// `init` hides this namespace from the command palette.
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey
    ]
);
53
54pub fn init(cx: &mut AppContext) {
55 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
56 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
57 filter.filtered_namespaces.insert("assistant");
58 });
59 }
60
61 settings::register::<AssistantSettings>(cx);
62 cx.add_action(
63 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
64 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
65 this.update(cx, |this, cx| this.add_context(cx))
66 }
67
68 workspace.focus_panel::<AssistantPanel>(cx);
69 },
70 );
71 cx.add_action(AssistantEditor::assist);
72 cx.capture_action(AssistantEditor::cancel_last_assist);
73 cx.add_action(AssistantEditor::quote_selection);
74 cx.capture_action(AssistantEditor::copy);
75 cx.capture_action(AssistantEditor::split);
76 cx.capture_action(AssistantEditor::cycle_message_role);
77 cx.add_action(AssistantPanel::save_api_key);
78 cx.add_action(AssistantPanel::reset_api_key);
79 cx.add_action(
80 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
81 workspace.toggle_panel_focus::<AssistantPanel>(cx);
82 },
83 );
84}
85
/// Events emitted by [`AssistantPanel`] and interpreted by its `Panel` impl.
pub enum AssistantPanelEvent {
    /// The panel's pane was zoomed in.
    ZoomIn,
    /// The panel's pane was zoomed out.
    ZoomOut,
    /// The panel's pane received focus.
    Focus,
    /// The panel's pane emitted `Remove`; the panel should close.
    Close,
    /// The dock position in the assistant settings changed.
    DockPositionChanged,
}
93
/// Dockable panel hosting assistant conversations ("contexts") in a pane.
pub struct AssistantPanel {
    /// Width used when docked left/right; `None` falls back to the settings default.
    width: Option<f32>,
    /// Height used when docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    /// Pane whose items are the assistant editors.
    pane: ViewHandle<Pane>,
    /// OpenAI API key, shared with every `AssistantEditor` created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Present while the panel is prompting the user to enter an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Whether the platform credential store has already been queried for a key.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
}
105
impl AssistantPanel {
    /// Asynchronously creates the assistant panel inside `workspace`.
    ///
    /// The panel owns a dedicated [`Pane`] that cannot split, navigate, or
    /// accept drops. Its tab bar offers "New Context" and "Toggle Zoom"
    /// buttons, and its toolbar hosts a buffer search bar. Previously saved
    /// panel state is not restored yet (see the TODO below).
    pub fn load(
        workspace: WeakViewHandle<Workspace>,
        cx: AsyncAppContext,
    ) -> Task<Result<ViewHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            // TODO: deserialize state.
            workspace.update(&mut cx, |workspace, cx| {
                cx.add_view::<Self, _>(|cx| {
                    // Weak handle to the panel being built. The "New Context" button
                    // upgrades it on click and does nothing if the panel is gone.
                    let weak_self = cx.weak_handle();
                    let pane = cx.add_view(|cx| {
                        let mut pane = Pane::new(
                            workspace.weak_handle(),
                            workspace.project().clone(),
                            workspace.app_state().background_actions,
                            Default::default(),
                            cx,
                        );
                        pane.set_can_split(false, cx);
                        pane.set_can_navigate(false, cx);
                        // Reject every drop onto this pane.
                        pane.on_can_drop(move |_, _| false);
                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
                            let weak_self = weak_self.clone();
                            Flex::row()
                                .with_child(Pane::render_tab_bar_button(
                                    0,
                                    "icons/plus_12.svg",
                                    false,
                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
                                    cx,
                                    move |_, cx| {
                                        let weak_self = weak_self.clone();
                                        // Deferred, presumably so the panel is not
                                        // updated while the pane is mid-update.
                                        cx.window_context().defer(move |cx| {
                                            if let Some(this) = weak_self.upgrade(cx) {
                                                this.update(cx, |this, cx| this.add_context(cx));
                                            }
                                        })
                                    },
                                    None,
                                ))
                                .with_child(Pane::render_tab_bar_button(
                                    1,
                                    // Icon and active state reflect the current zoom state.
                                    if pane.is_zoomed() {
                                        "icons/minimize_8.svg"
                                    } else {
                                        "icons/maximize_8.svg"
                                    },
                                    pane.is_zoomed(),
                                    Some((
                                        "Toggle Zoom".into(),
                                        Some(Box::new(workspace::ToggleZoom)),
                                    )),
                                    cx,
                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
                                    None,
                                ))
                                .into_any()
                        });
                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
                        pane.toolbar()
                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
                        pane
                    });

                    let mut this = Self {
                        pane,
                        api_key: Rc::new(RefCell::new(None)),
                        api_key_editor: None,
                        has_read_credentials: false,
                        languages: workspace.app_state().languages.clone(),
                        fs: workspace.app_state().fs.clone(),
                        width: None,
                        height: None,
                        subscriptions: Default::default(),
                    };

                    // Remember the current dock position so the settings observer
                    // only emits `DockPositionChanged` when it actually changes.
                    let mut old_dock_position = this.position(cx);
                    this.subscriptions = vec![
                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
                        cx.subscribe(&this.pane, Self::handle_pane_event),
                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
                            let new_dock_position = this.position(cx);
                            if new_dock_position != old_dock_position {
                                old_dock_position = new_dock_position;
                                cx.emit(AssistantPanelEvent::DockPositionChanged);
                            }
                        }),
                    ];

                    this
                })
            })
        })
    }

    /// Re-emits the pane's zoom, focus, and remove events as panel events;
    /// all other pane events are ignored.
    fn handle_pane_event(
        &mut self,
        _pane: ViewHandle<Pane>,
        event: &pane::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
            _ => {}
        }
    }

    /// Opens a new conversation as an `AssistantEditor` item in the pane.
    ///
    /// The editor shares this panel's API key cell. `focus` (whether the panel
    /// currently has focus) is passed to `add_item`, presumably so a new context
    /// only steals focus when the panel already had it.
    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
        let focus = self.has_focus(cx);
        let editor = cx
            .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
        self.subscriptions
            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
        self.pane.update(cx, |pane, cx| {
            pane.add_item(Box::new(editor), true, focus, None, cx)
        });
    }

    /// Re-renders the pane when an editor's tab content (its title) changes.
    fn handle_assistant_editor_event(
        &mut self,
        _: ViewHandle<AssistantEditor>,
        event: &AssistantEditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
        }
    }

    /// Handles `menu::Confirm` while the API key prompt is showing.
    ///
    /// A non-empty key is written to the platform credential store (failures
    /// are logged), stored in the shared key cell, and the prompt is dismissed.
    /// An empty key is ignored. When no prompt is showing, the action propagates.
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        if let Some(api_key) = self
            .api_key_editor
            .as_ref()
            .map(|editor| editor.read(cx).text(cx))
        {
            if !api_key.is_empty() {
                cx.platform()
                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
                *self.api_key.borrow_mut() = Some(api_key);
                self.api_key_editor.take();
                cx.focus_self();
                cx.notify();
            }
        } else {
            cx.propagate_action();
        }
    }

    /// Deletes the stored API key (failures are logged), clears the in-memory
    /// key, and shows the key prompt again.
    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
        self.api_key.take();
        self.api_key_editor = Some(build_api_key_editor(cx));
        cx.focus_self();
        cx.notify();
    }
}
266
267fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
268 cx.add_view(|cx| {
269 let mut editor = Editor::single_line(
270 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
271 cx,
272 );
273 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
274 editor
275 })
276}
277
// The panel emits `AssistantPanelEvent`s to its observers.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
281
282impl View for AssistantPanel {
283 fn ui_name() -> &'static str {
284 "AssistantPanel"
285 }
286
287 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
288 let style = &theme::current(cx).assistant;
289 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
290 Flex::column()
291 .with_child(
292 Text::new(
293 "Paste your OpenAI API key and press Enter to use the assistant",
294 style.api_key_prompt.text.clone(),
295 )
296 .aligned(),
297 )
298 .with_child(
299 ChildView::new(api_key_editor, cx)
300 .contained()
301 .with_style(style.api_key_editor.container)
302 .aligned(),
303 )
304 .contained()
305 .with_style(style.api_key_prompt.container)
306 .aligned()
307 .into_any()
308 } else {
309 ChildView::new(&self.pane, cx).into_any()
310 }
311 }
312
313 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
314 if cx.is_self_focused() {
315 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
316 cx.focus(api_key_editor);
317 } else {
318 cx.focus(&self.pane);
319 }
320 }
321 }
322}
323
324impl Panel for AssistantPanel {
325 fn position(&self, cx: &WindowContext) -> DockPosition {
326 match settings::get::<AssistantSettings>(cx).dock {
327 AssistantDockPosition::Left => DockPosition::Left,
328 AssistantDockPosition::Bottom => DockPosition::Bottom,
329 AssistantDockPosition::Right => DockPosition::Right,
330 }
331 }
332
333 fn position_is_valid(&self, _: DockPosition) -> bool {
334 true
335 }
336
337 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
338 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
339 let dock = match position {
340 DockPosition::Left => AssistantDockPosition::Left,
341 DockPosition::Bottom => AssistantDockPosition::Bottom,
342 DockPosition::Right => AssistantDockPosition::Right,
343 };
344 settings.dock = Some(dock);
345 });
346 }
347
348 fn size(&self, cx: &WindowContext) -> f32 {
349 let settings = settings::get::<AssistantSettings>(cx);
350 match self.position(cx) {
351 DockPosition::Left | DockPosition::Right => {
352 self.width.unwrap_or_else(|| settings.default_width)
353 }
354 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
355 }
356 }
357
358 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
359 match self.position(cx) {
360 DockPosition::Left | DockPosition::Right => self.width = Some(size),
361 DockPosition::Bottom => self.height = Some(size),
362 }
363 cx.notify();
364 }
365
366 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
367 matches!(event, AssistantPanelEvent::ZoomIn)
368 }
369
370 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
371 matches!(event, AssistantPanelEvent::ZoomOut)
372 }
373
374 fn is_zoomed(&self, cx: &WindowContext) -> bool {
375 self.pane.read(cx).is_zoomed()
376 }
377
378 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
379 self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
380 }
381
382 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
383 if active {
384 if self.api_key.borrow().is_none() && !self.has_read_credentials {
385 self.has_read_credentials = true;
386 let api_key = if let Some((_, api_key)) = cx
387 .platform()
388 .read_credentials(OPENAI_API_URL)
389 .log_err()
390 .flatten()
391 {
392 String::from_utf8(api_key).log_err()
393 } else {
394 None
395 };
396 if let Some(api_key) = api_key {
397 *self.api_key.borrow_mut() = Some(api_key);
398 } else if self.api_key_editor.is_none() {
399 self.api_key_editor = Some(build_api_key_editor(cx));
400 cx.notify();
401 }
402 }
403
404 if self.pane.read(cx).items_len() == 0 {
405 self.add_context(cx);
406 }
407 }
408 }
409
410 fn icon_path(&self) -> &'static str {
411 "icons/robot_14.svg"
412 }
413
414 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
415 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
416 }
417
418 fn should_change_position_on_event(event: &Self::Event) -> bool {
419 matches!(event, AssistantPanelEvent::DockPositionChanged)
420 }
421
422 fn should_activate_on_event(_: &Self::Event) -> bool {
423 false
424 }
425
426 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
427 matches!(event, AssistantPanelEvent::Close)
428 }
429
430 fn has_focus(&self, cx: &WindowContext) -> bool {
431 self.pane.read(cx).has_focus()
432 || self
433 .api_key_editor
434 .as_ref()
435 .map_or(false, |editor| editor.is_focused(cx))
436 }
437
438 fn is_focus_event(event: &Self::Event) -> bool {
439 matches!(event, AssistantPanelEvent::Focus)
440 }
441}
442
/// Events emitted by an [`Assistant`] conversation model.
enum AssistantEvent {
    /// Messages were added, split, re-roled, or their buffer text edited.
    MessagesEdited,
    /// The generated conversation summary (tab title) changed.
    SummaryChanged,
    /// A chunk of a streamed completion was inserted into the buffer.
    StreamedCompletion,
}
448
/// A single assistant conversation: a buffer partitioned into messages, plus
/// the state of in-flight completion, summary, and token-count tasks.
struct Assistant {
    /// Conversation text; each message occupies a range of this buffer.
    buffer: ModelHandle<Buffer>,
    /// Start anchors of the messages, in buffer order.
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time, and last error of each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id handed out to the next message that is created.
    next_message_id: MessageId,
    /// Title generated by `summarize`, accumulated as it streams in.
    summary: Option<String>,
    /// Task streaming the current summary request.
    pending_summary: Task<Option<()>>,
    /// Counter from which `PendingCompletion` ids are allocated.
    completion_count: usize,
    /// Completions started by `assist`; `cancel_last_assist` pops the most recent.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name used for requests and token counting.
    model: String,
    /// Latest token count of the conversation; `None` until first computed.
    token_count: Option<usize>,
    /// Context window size of `model`, as reported by `tiktoken_rs`.
    max_token_count: usize,
    /// Debounced background task recomputing `token_count`.
    pending_token_count: Task<Option<()>>,
    /// OpenAI API key shared with the owning panel; `None` when unset.
    api_key: Rc<RefCell<Option<String>>>,
    _subscriptions: Vec<Subscription>,
}
465
// The conversation model emits `AssistantEvent`s to its subscribers.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
469
470impl Assistant {
471 fn new(
472 api_key: Rc<RefCell<Option<String>>>,
473 language_registry: Arc<LanguageRegistry>,
474 cx: &mut ModelContext<Self>,
475 ) -> Self {
476 let model = "gpt-3.5-turbo";
477 let markdown = language_registry.language_for_name("Markdown");
478 let buffer = cx.add_model(|cx| {
479 let mut buffer = Buffer::new(0, "", cx);
480 buffer.set_language_registry(language_registry);
481 cx.spawn_weak(|buffer, mut cx| async move {
482 let markdown = markdown.await?;
483 let buffer = buffer
484 .upgrade(&cx)
485 .ok_or_else(|| anyhow!("buffer was dropped"))?;
486 buffer.update(&mut cx, |buffer, cx| {
487 buffer.set_language(Some(markdown), cx)
488 });
489 anyhow::Ok(())
490 })
491 .detach_and_log_err(cx);
492 buffer
493 });
494
495 let mut this = Self {
496 message_anchors: Default::default(),
497 messages_metadata: Default::default(),
498 next_message_id: Default::default(),
499 summary: None,
500 pending_summary: Task::ready(None),
501 completion_count: Default::default(),
502 pending_completions: Default::default(),
503 token_count: None,
504 max_token_count: tiktoken_rs::model::get_context_size(model),
505 pending_token_count: Task::ready(None),
506 model: model.into(),
507 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
508 api_key,
509 buffer,
510 };
511 let message = MessageAnchor {
512 id: MessageId(post_inc(&mut this.next_message_id.0)),
513 start: language::Anchor::MIN,
514 };
515 this.message_anchors.push(message.clone());
516 this.messages_metadata.insert(
517 message.id,
518 MessageMetadata {
519 role: Role::User,
520 sent_at: Local::now(),
521 error: None,
522 },
523 );
524
525 this.count_remaining_tokens(cx);
526 this
527 }
528
529 fn handle_buffer_event(
530 &mut self,
531 _: ModelHandle<Buffer>,
532 event: &language::Event,
533 cx: &mut ModelContext<Self>,
534 ) {
535 match event {
536 language::Event::Edited => {
537 self.count_remaining_tokens(cx);
538 cx.emit(AssistantEvent::MessagesEdited);
539 }
540 _ => {}
541 }
542 }
543
544 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
545 let messages = self
546 .open_ai_request_messages(cx)
547 .into_iter()
548 .filter_map(|message| {
549 Some(tiktoken_rs::ChatCompletionRequestMessage {
550 role: match message.role {
551 Role::User => "user".into(),
552 Role::Assistant => "assistant".into(),
553 Role::System => "system".into(),
554 },
555 content: message.content,
556 name: None,
557 })
558 })
559 .collect::<Vec<_>>();
560 let model = self.model.clone();
561 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
562 async move {
563 cx.background().timer(Duration::from_millis(200)).await;
564 let token_count = cx
565 .background()
566 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
567 .await?;
568
569 this.upgrade(&cx)
570 .ok_or_else(|| anyhow!("assistant was dropped"))?
571 .update(&mut cx, |this, cx| {
572 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
573 this.token_count = Some(token_count);
574 cx.notify()
575 });
576 anyhow::Ok(())
577 }
578 .log_err()
579 });
580 }
581
582 fn remaining_tokens(&self) -> Option<isize> {
583 Some(self.max_token_count as isize - self.token_count? as isize)
584 }
585
586 fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
587 self.model = model;
588 self.count_remaining_tokens(cx);
589 cx.notify();
590 }
591
592 fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
593 let request = OpenAIRequest {
594 model: self.model.clone(),
595 messages: self.open_ai_request_messages(cx),
596 stream: true,
597 };
598
599 let api_key = self.api_key.borrow().clone()?;
600 let stream = stream_completion(api_key, cx.background().clone(), request);
601 let assistant_message =
602 self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
603 let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
604 let task = cx.spawn_weak({
605 |this, mut cx| async move {
606 let assistant_message_id = assistant_message.id;
607 let stream_completion = async {
608 let mut messages = stream.await?;
609
610 while let Some(message) = messages.next().await {
611 let mut message = message?;
612 if let Some(choice) = message.choices.pop() {
613 this.upgrade(&cx)
614 .ok_or_else(|| anyhow!("assistant was dropped"))?
615 .update(&mut cx, |this, cx| {
616 let text: Arc<str> = choice.delta.content?.into();
617 let message_ix = this
618 .message_anchors
619 .iter()
620 .position(|message| message.id == assistant_message_id)?;
621 this.buffer.update(cx, |buffer, cx| {
622 let offset = if message_ix + 1 == this.message_anchors.len()
623 {
624 buffer.len()
625 } else {
626 this.message_anchors[message_ix + 1]
627 .start
628 .to_offset(buffer)
629 .saturating_sub(1)
630 };
631 buffer.edit([(offset..offset, text)], None, cx);
632 });
633 cx.emit(AssistantEvent::StreamedCompletion);
634
635 Some(())
636 });
637 }
638 }
639
640 this.upgrade(&cx)
641 .ok_or_else(|| anyhow!("assistant was dropped"))?
642 .update(&mut cx, |this, cx| {
643 this.pending_completions
644 .retain(|completion| completion.id != this.completion_count);
645 this.summarize(cx);
646 });
647
648 anyhow::Ok(())
649 };
650
651 let result = stream_completion.await;
652 if let Some(this) = this.upgrade(&cx) {
653 this.update(&mut cx, |this, cx| {
654 if let Err(error) = result {
655 if let Some(metadata) =
656 this.messages_metadata.get_mut(&assistant_message.id)
657 {
658 metadata.error = Some(error.to_string().trim().into());
659 cx.notify();
660 }
661 }
662 });
663 }
664 }
665 });
666
667 self.pending_completions.push(PendingCompletion {
668 id: post_inc(&mut self.completion_count),
669 _task: task,
670 });
671 Some((assistant_message, user_message))
672 }
673
674 fn cancel_last_assist(&mut self) -> bool {
675 self.pending_completions.pop().is_some()
676 }
677
678 fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
679 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
680 metadata.role.cycle();
681 cx.emit(AssistantEvent::MessagesEdited);
682 cx.notify();
683 }
684 }
685
686 fn insert_message_after(
687 &mut self,
688 message_id: MessageId,
689 role: Role,
690 cx: &mut ModelContext<Self>,
691 ) -> Option<MessageAnchor> {
692 if let Some(prev_message_ix) = self
693 .message_anchors
694 .iter()
695 .position(|message| message.id == message_id)
696 {
697 let start = self.buffer.update(cx, |buffer, cx| {
698 let offset = self.message_anchors[prev_message_ix + 1..]
699 .iter()
700 .find(|message| message.start.is_valid(buffer))
701 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
702 buffer.edit([(offset..offset, "\n")], None, cx);
703 buffer.anchor_before(offset + 1)
704 });
705 let message = MessageAnchor {
706 id: MessageId(post_inc(&mut self.next_message_id.0)),
707 start,
708 };
709 self.message_anchors
710 .insert(prev_message_ix + 1, message.clone());
711 self.messages_metadata.insert(
712 message.id,
713 MessageMetadata {
714 role,
715 sent_at: Local::now(),
716 error: None,
717 },
718 );
719 cx.emit(AssistantEvent::MessagesEdited);
720 Some(message)
721 } else {
722 None
723 }
724 }
725
726 fn split_message(
727 &mut self,
728 range: Range<usize>,
729 cx: &mut ModelContext<Self>,
730 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
731 let start_message = self.message_for_offset(range.start, cx);
732 let end_message = self.message_for_offset(range.end, cx);
733 if let Some((start_message, end_message)) = start_message.zip(end_message) {
734 // Prevent splitting when range spans multiple messages.
735 if start_message.index != end_message.index {
736 return (None, None);
737 }
738
739 let message = start_message;
740 let role = message.role;
741 let mut edited_buffer = false;
742
743 let mut suffix_start = None;
744 if range.start > message.range.start && range.end < message.range.end - 1 {
745 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
746 suffix_start = Some(range.end + 1);
747 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
748 suffix_start = Some(range.end);
749 }
750 }
751
752 let suffix = if let Some(suffix_start) = suffix_start {
753 MessageAnchor {
754 id: MessageId(post_inc(&mut self.next_message_id.0)),
755 start: self.buffer.read(cx).anchor_before(suffix_start),
756 }
757 } else {
758 self.buffer.update(cx, |buffer, cx| {
759 buffer.edit([(range.end..range.end, "\n")], None, cx);
760 });
761 edited_buffer = true;
762 MessageAnchor {
763 id: MessageId(post_inc(&mut self.next_message_id.0)),
764 start: self.buffer.read(cx).anchor_before(range.end + 1),
765 }
766 };
767
768 self.message_anchors
769 .insert(message.index + 1, suffix.clone());
770 self.messages_metadata.insert(
771 suffix.id,
772 MessageMetadata {
773 role,
774 sent_at: Local::now(),
775 error: None,
776 },
777 );
778
779 let new_messages = if range.start == range.end || range.start == message.range.start {
780 (None, Some(suffix))
781 } else {
782 let mut prefix_end = None;
783 if range.start > message.range.start && range.end < message.range.end - 1 {
784 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
785 prefix_end = Some(range.start + 1);
786 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
787 == Some('\n')
788 {
789 prefix_end = Some(range.start);
790 }
791 }
792
793 let selection = if let Some(prefix_end) = prefix_end {
794 cx.emit(AssistantEvent::MessagesEdited);
795 MessageAnchor {
796 id: MessageId(post_inc(&mut self.next_message_id.0)),
797 start: self.buffer.read(cx).anchor_before(prefix_end),
798 }
799 } else {
800 self.buffer.update(cx, |buffer, cx| {
801 buffer.edit([(range.start..range.start, "\n")], None, cx)
802 });
803 edited_buffer = true;
804 MessageAnchor {
805 id: MessageId(post_inc(&mut self.next_message_id.0)),
806 start: self.buffer.read(cx).anchor_before(range.end + 1),
807 }
808 };
809
810 self.message_anchors
811 .insert(message.index + 1, selection.clone());
812 self.messages_metadata.insert(
813 selection.id,
814 MessageMetadata {
815 role,
816 sent_at: Local::now(),
817 error: None,
818 },
819 );
820 (Some(selection), Some(suffix))
821 };
822
823 if !edited_buffer {
824 cx.emit(AssistantEvent::MessagesEdited);
825 }
826 new_messages
827 } else {
828 (None, None)
829 }
830 }
831
832 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
833 if self.message_anchors.len() >= 2 && self.summary.is_none() {
834 let api_key = self.api_key.borrow().clone();
835 if let Some(api_key) = api_key {
836 let mut messages = self.open_ai_request_messages(cx);
837 messages.truncate(2);
838 messages.push(RequestMessage {
839 role: Role::User,
840 content: "Summarize the conversation into a short title without punctuation"
841 .into(),
842 });
843 let request = OpenAIRequest {
844 model: self.model.clone(),
845 messages,
846 stream: true,
847 };
848
849 let stream = stream_completion(api_key, cx.background().clone(), request);
850 self.pending_summary = cx.spawn(|this, mut cx| {
851 async move {
852 let mut messages = stream.await?;
853
854 while let Some(message) = messages.next().await {
855 let mut message = message?;
856 if let Some(choice) = message.choices.pop() {
857 let text = choice.delta.content.unwrap_or_default();
858 this.update(&mut cx, |this, cx| {
859 this.summary.get_or_insert(String::new()).push_str(&text);
860 cx.emit(AssistantEvent::SummaryChanged);
861 });
862 }
863 }
864
865 anyhow::Ok(())
866 }
867 .log_err()
868 });
869 }
870 }
871 }
872
873 fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
874 let buffer = self.buffer.read(cx);
875 self.messages(cx)
876 .map(|message| RequestMessage {
877 role: message.role,
878 content: buffer.text_for_range(message.range).collect(),
879 })
880 .collect()
881 }
882
883 fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
884 let mut messages = self.messages(cx).peekable();
885 while let Some(message) = messages.next() {
886 if message.range.contains(&offset) || messages.peek().is_none() {
887 return Some(message);
888 }
889 }
890 None
891 }
892
893 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
894 let buffer = self.buffer.read(cx);
895 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
896 iter::from_fn(move || {
897 while let Some((ix, message_anchor)) = message_anchors.next() {
898 let metadata = self.messages_metadata.get(&message_anchor.id)?;
899 let message_start = message_anchor.start.to_offset(buffer);
900 let mut message_end = None;
901 while let Some((_, next_message)) = message_anchors.peek() {
902 if next_message.start.is_valid(buffer) {
903 message_end = Some(next_message.start);
904 break;
905 } else {
906 message_anchors.next();
907 }
908 }
909 let message_end = message_end
910 .unwrap_or(language::Anchor::MAX)
911 .to_offset(buffer);
912 return Some(Message {
913 index: ix,
914 range: message_start..message_end,
915 id: message_anchor.id,
916 anchor: message_anchor.start,
917 role: metadata.role,
918 sent_at: metadata.sent_at,
919 error: metadata.error.clone(),
920 });
921 }
922 None
923 })
924 }
925}
926
/// A completion started by `Assistant::assist` that has not been removed yet.
struct PendingCompletion {
    /// Id allocated from `Assistant::completion_count`.
    id: usize,
    // Task streaming the completion. `cancel_last_assist` pops and drops this
    // entry; presumably dropping the gpui `Task` cancels the stream.
    _task: Task<()>,
}
931
/// Events emitted by [`AssistantEditor`] to the panel that hosts it.
enum AssistantEditorEvent {
    /// The editor's tab content (e.g. its summary title) changed.
    TabContentChanged,
}
935
/// Scroll position remembered relative to the cursor, so the viewport can be
/// restored while a completion streams text into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// x: horizontal scroll; y: number of display rows between the top of the
    /// viewport and the cursor row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection when the position was captured.
    cursor: Anchor,
}
941
/// View that edits one [`Assistant`] conversation.
struct AssistantEditor {
    /// The conversation model.
    assistant: ModelHandle<Assistant>,
    /// Text editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the header blocks inserted by `update_message_headers`.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to keep while completions stream in;
    /// `None` when the cursor is not in view.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
949
950impl AssistantEditor {
951 fn new(
952 api_key: Rc<RefCell<Option<String>>>,
953 language_registry: Arc<LanguageRegistry>,
954 cx: &mut ViewContext<Self>,
955 ) -> Self {
956 let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
957 let editor = cx.add_view(|cx| {
958 let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
959 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
960 editor.set_show_gutter(false, cx);
961 editor
962 });
963
964 let _subscriptions = vec![
965 cx.observe(&assistant, |_, _, cx| cx.notify()),
966 cx.subscribe(&assistant, Self::handle_assistant_event),
967 cx.subscribe(&editor, Self::handle_editor_event),
968 ];
969
970 let mut this = Self {
971 assistant,
972 editor,
973 blocks: Default::default(),
974 scroll_position: None,
975 _subscriptions,
976 };
977 this.update_message_headers(cx);
978 this
979 }
980
981 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
982 let user_message = self.assistant.update(cx, |assistant, cx| {
983 let (_, user_message) = assistant.assist(cx)?;
984 Some(user_message)
985 });
986
987 if let Some(user_message) = user_message {
988 let cursor = user_message
989 .start
990 .to_offset(&self.assistant.read(cx).buffer.read(cx));
991 self.editor.update(cx, |editor, cx| {
992 editor.change_selections(
993 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
994 cx,
995 |selections| selections.select_ranges([cursor..cursor]),
996 );
997 });
998 }
999 }
1000
1001 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1002 if !self
1003 .assistant
1004 .update(cx, |assistant, _| assistant.cancel_last_assist())
1005 {
1006 cx.propagate_action();
1007 }
1008 }
1009
1010 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1011 let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
1012 self.assistant.update(cx, |assistant, cx| {
1013 if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
1014 assistant.cycle_message_role(message.id, cx);
1015 }
1016 });
1017 }
1018
1019 fn handle_assistant_event(
1020 &mut self,
1021 _: ModelHandle<Assistant>,
1022 event: &AssistantEvent,
1023 cx: &mut ViewContext<Self>,
1024 ) {
1025 match event {
1026 AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1027 AssistantEvent::SummaryChanged => {
1028 cx.emit(AssistantEditorEvent::TabContentChanged);
1029 }
1030 AssistantEvent::StreamedCompletion => {
1031 self.editor.update(cx, |editor, cx| {
1032 if let Some(scroll_position) = self.scroll_position {
1033 let snapshot = editor.snapshot(cx);
1034 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1035 let scroll_top =
1036 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1037 editor.set_scroll_position(
1038 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1039 cx,
1040 );
1041 }
1042 });
1043 }
1044 }
1045 }
1046
1047 fn handle_editor_event(
1048 &mut self,
1049 _: ViewHandle<Editor>,
1050 event: &editor::Event,
1051 cx: &mut ViewContext<Self>,
1052 ) {
1053 match event {
1054 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1055 let cursor_scroll_position = self.cursor_scroll_position(cx);
1056 if *autoscroll {
1057 self.scroll_position = cursor_scroll_position;
1058 } else if self.scroll_position != cursor_scroll_position {
1059 self.scroll_position = None;
1060 }
1061 }
1062 editor::Event::SelectionsChanged { .. } => {
1063 self.scroll_position = self.cursor_scroll_position(cx);
1064 }
1065 _ => {}
1066 }
1067 }
1068
1069 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1070 self.editor.update(cx, |editor, cx| {
1071 let snapshot = editor.snapshot(cx);
1072 let cursor = editor.selections.newest_anchor().head();
1073 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1074 let scroll_position = editor
1075 .scroll_manager
1076 .anchor()
1077 .scroll_position(&snapshot.display_snapshot);
1078
1079 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1080 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1081 Some(ScrollPosition {
1082 cursor,
1083 offset_before_cursor: vec2f(
1084 scroll_position.x(),
1085 cursor_row - scroll_position.y(),
1086 ),
1087 })
1088 } else {
1089 None
1090 }
1091 })
1092 }
1093
    /// Rebuilds the header block rendered above every message in the editor.
    ///
    /// All header blocks previously inserted (tracked in `self.blocks`) are
    /// removed, and one sticky, two-line-high block is inserted above each
    /// message yielded by `Assistant::messages`. A header shows:
    /// - the sender role ("You" / "Assistant" / "System"); a left mouse-down on
    ///   it cycles the message's role via `Assistant::cycle_message_role`;
    /// - the time the message was sent, formatted with `%I:%M%P`;
    /// - when the message carries an error, an error icon whose tooltip is the
    ///   error text.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor's multibuffer is expected to wrap exactly one buffer
            // (the `unwrap` panics otherwise); message anchors are mapped into
            // its only excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .assistant
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let assistant = self.assistant.clone();
                        // The render closure takes ownership of `message`, so each
                        // header shows the message state captured by this rebuild.
                        move |cx| {
                            // Marker types that give the sender button and the
                            // error tooltip their own element-id namespaces.
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state, false);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state, false);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state, false);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let assistant = assistant.clone();
                                move |_, _, cx| {
                                    assistant.update(cx, |assistant, cx| {
                                        assistant.cycle_message_role(message_id, cx)
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                // Only rendered when the message has an error.
                                .with_children(message.error.as_ref().map(|error| {
                                    Svg::new("icons/circle_x_mark_12.svg")
                                        .with_color(style.error_icon.color)
                                        .constrained()
                                        .with_width(style.error_icon.width)
                                        .contained()
                                        .with_style(style.error_icon.container)
                                        .with_tooltip::<ErrorTooltip>(
                                            message_id.0,
                                            error.to_string(),
                                            None,
                                            theme.tooltip.clone(),
                                            cx,
                                        )
                                        .aligned()
                                }))
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            // Swap the header set: drop every old block, then remember the ids
            // of the newly inserted ones for the next rebuild.
            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1195
    /// Action handler: quotes the newest selection of the workspace's active
    /// editor into the assistant panel's active context.
    ///
    /// Does nothing if the workspace has no `AssistantPanel` or the active item
    /// is not an `Editor`. Otherwise the panel is focused (via
    /// `toggle_panel_focus`) if it does not already have focus. This happens
    /// even when the selection is empty.
    ///
    /// When the selection is non-empty, its text is formatted and passed to
    /// `Editor::insert` on the active `AssistantEditor`'s editor:
    /// - Markdown selections become a `> `-prefixed block quote;
    /// - anything else becomes a fenced code block. The fence is tagged with
    ///   the lowercased language name, or left untagged when the start and end
    ///   of the selection resolve to different (or no) languages.
    ///
    /// If the panel has no active `AssistantEditor`, a "no active context"
    /// error is logged instead.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // Only tag the fence when both ends of the selection agree on the
            // language.
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                if let Some(assistant) = panel
                    .pane
                    .read(cx)
                    .active_item()
                    .and_then(|item| item.downcast::<AssistantEditor>())
                    .ok_or_else(|| anyhow!("no active context"))
                    .log_err()
                {
                    assistant.update(cx, |assistant, cx| {
                        assistant
                            .editor
                            .update(cx, |editor, cx| editor.insert(&text, cx))
                    });
                }
            });
        }
    }
1260
1261 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1262 let editor = self.editor.read(cx);
1263 let assistant = self.assistant.read(cx);
1264 if editor.selections.count() == 1 {
1265 let selection = editor.selections.newest::<usize>(cx);
1266 let mut copied_text = String::new();
1267 let mut spanned_messages = 0;
1268 for message in assistant.messages(cx) {
1269 if message.range.start >= selection.range().end {
1270 break;
1271 } else if message.range.end >= selection.range().start {
1272 let range = cmp::max(message.range.start, selection.range().start)
1273 ..cmp::min(message.range.end, selection.range().end);
1274 if !range.is_empty() {
1275 spanned_messages += 1;
1276 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1277 for chunk in assistant.buffer.read(cx).text_for_range(range) {
1278 copied_text.push_str(&chunk);
1279 }
1280 copied_text.push('\n');
1281 }
1282 }
1283 }
1284
1285 if spanned_messages > 1 {
1286 cx.platform()
1287 .write_to_clipboard(ClipboardItem::new(copied_text));
1288 return;
1289 }
1290 }
1291
1292 cx.propagate_action();
1293 }
1294
1295 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1296 self.assistant.update(cx, |assistant, cx| {
1297 let selections = self.editor.read(cx).selections.disjoint_anchors();
1298 for selection in selections.into_iter() {
1299 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
1300 let range = selection
1301 .map(|endpoint| endpoint.to_offset(&buffer))
1302 .range();
1303 assistant.split_message(range, cx);
1304 }
1305 });
1306 }
1307
1308 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1309 self.assistant.update(cx, |assistant, cx| {
1310 let new_model = match assistant.model.as_str() {
1311 "gpt-4" => "gpt-3.5-turbo",
1312 _ => "gpt-4",
1313 };
1314 assistant.set_model(new_model.into(), cx);
1315 });
1316 }
1317
1318 fn title(&self, cx: &AppContext) -> String {
1319 self.assistant
1320 .read(cx)
1321 .summary
1322 .clone()
1323 .unwrap_or_else(|| "New Context".into())
1324 }
1325}
1326
/// Declares `AssistantEditorEvent` as the event type emitted by
/// `AssistantEditor`.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1330
impl View for AssistantEditor {
    fn ui_name() -> &'static str {
        "AssistantEditor"
    }

    /// Renders the inner editor with an overlay aligned to the top-right.
    ///
    /// The overlay shows the current model name; clicking it calls
    /// `cycle_model`. When `remaining_tokens()` returns a value, the count is
    /// shown next to the model name. It uses the `no_remaining_tokens` style
    /// when the count is zero or negative, and `remaining_tokens` otherwise.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        // Marker type giving the model button its own element-id namespace.
        enum Model {}
        let theme = &theme::current(cx).assistant;
        let assistant = &self.assistant.read(cx);
        let model = assistant.model.clone();
        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
            let remaining_tokens_style = if remaining_tokens <= 0 {
                &theme.no_remaining_tokens
            } else {
                &theme.remaining_tokens
            };
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container)
        });

        Stack::new()
            .with_child(
                ChildView::new(&self.editor, cx)
                    .contained()
                    .with_style(theme.container),
            )
            .with_child(
                Flex::row()
                    .with_child(
                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
                            let style = theme.model.style_for(state, false);
                            Label::new(model, style.text.clone())
                                .contained()
                                .with_style(style.container)
                        })
                        .with_cursor_style(CursorStyle::PointingHand)
                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
                    )
                    .with_children(remaining_tokens)
                    .contained()
                    .with_style(theme.model_info_container)
                    .aligned()
                    .top()
                    .right(),
            )
            .into_any()
    }

    /// Forwards focus to the inner editor when this view itself (rather than
    /// one of its descendants) gains focus.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        if cx.is_self_focused() {
            cx.focus(&self.editor);
        }
    }
}
1389
1390impl Item for AssistantEditor {
1391 fn tab_content<V: View>(
1392 &self,
1393 _: Option<usize>,
1394 style: &theme::Tab,
1395 cx: &gpui::AppContext,
1396 ) -> AnyElement<V> {
1397 let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1398 Label::new(title, style.label.clone()).into_any()
1399 }
1400
1401 fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1402 Some(self.title(cx).into())
1403 }
1404
1405 fn as_searchable(
1406 &self,
1407 _: &ViewHandle<Self>,
1408 ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1409 Some(Box::new(self.editor.clone()))
1410 }
1411}
1412
/// Identifier of a message in an assistant context.
///
/// The inner value also serves as the element id of the message's header
/// controls. It is presumably unique within one `Assistant`; verify against
/// where ids are allocated.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1415
/// Marks where a message begins in the assistant's buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    // Buffer anchor at the start of the message.
    start: language::Anchor,
}
1421
/// Per-message attributes that are not derived from the buffer contents.
#[derive(Clone, Debug)]
struct MessageMetadata {
    // Who authored the message (user, assistant or system).
    role: Role,
    // Local time at which the message was sent.
    sent_at: DateTime<Local>,
    // Error attached to the message, if any; shown as a tooltip in its header.
    error: Option<Arc<str>>,
}
1428
/// A message as resolved against the current buffer contents.
#[derive(Clone, Debug)]
pub struct Message {
    // Offset range of the message's text in the assistant's buffer.
    range: Range<usize>,
    // NOTE(review): presumably the message's position among the context's
    // messages; confirm where it is computed.
    index: usize,
    id: MessageId,
    // Buffer anchor at the start of the message; headers are placed here.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    error: Option<Arc<str>>,
}
1439
1440async fn stream_completion(
1441 api_key: String,
1442 executor: Arc<Background>,
1443 mut request: OpenAIRequest,
1444) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1445 request.stream = true;
1446
1447 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1448
1449 let json_data = serde_json::to_string(&request)?;
1450 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1451 .header("Content-Type", "application/json")
1452 .header("Authorization", format!("Bearer {}", api_key))
1453 .body(json_data)?
1454 .send_async()
1455 .await?;
1456
1457 let status = response.status();
1458 if status == StatusCode::OK {
1459 executor
1460 .spawn(async move {
1461 let mut lines = BufReader::new(response.body_mut()).lines();
1462
1463 fn parse_line(
1464 line: Result<String, io::Error>,
1465 ) -> Result<Option<OpenAIResponseStreamEvent>> {
1466 if let Some(data) = line?.strip_prefix("data: ") {
1467 let event = serde_json::from_str(&data)?;
1468 Ok(Some(event))
1469 } else {
1470 Ok(None)
1471 }
1472 }
1473
1474 while let Some(line) = lines.next().await {
1475 if let Some(event) = parse_line(line).transpose() {
1476 let done = event.as_ref().map_or(false, |event| {
1477 event
1478 .choices
1479 .last()
1480 .map_or(false, |choice| choice.finish_reason.is_some())
1481 });
1482 if tx.unbounded_send(event).is_err() {
1483 break;
1484 }
1485
1486 if done {
1487 break;
1488 }
1489 }
1490 }
1491
1492 anyhow::Ok(())
1493 })
1494 .detach();
1495
1496 Ok(rx)
1497 } else {
1498 let mut body = String::new();
1499 response.body_mut().read_to_string(&mut body).await?;
1500
1501 #[derive(Deserialize)]
1502 struct OpenAIResponse {
1503 error: OpenAIError,
1504 }
1505
1506 #[derive(Deserialize)]
1507 struct OpenAIError {
1508 message: String,
1509 }
1510
1511 match serde_json::from_str::<OpenAIResponse>(&body) {
1512 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1513 "Failed to connect to OpenAI API: {}",
1514 response.error.message,
1515 )),
1516
1517 _ => Err(anyhow!(
1518 "Failed to connect to OpenAI API: {} {}",
1519 response.status(),
1520 body,
1521 )),
1522 }
1523 }
1524}
1525
// Unit tests for how `Assistant` tracks message boundaries as the underlying
// buffer is edited, undone, redone and split.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    // Covers inserting messages after existing ones, the effect of buffer
    // edits on message ranges, and merging/unmerging messages when an edit
    // deletes across a message boundary (including undo and redo).
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A fresh assistant starts with a single, empty User message.
        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 grows message_1 to 0..1: the boundary
        // character added by `insert_message_after` belongs to the preceding
        // message.
        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between
        // message_2 and message_3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // Covers `split_message`. Caret splits inside a message reuse an existing
    // newline; splits where no newline can be reused insert one; a non-empty
    // range yields two new messages. New messages keep the User role here.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range produces a message for the selected text
        // and one for the text after it.
        let (message_6, message_7) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // Projects each message to `(id, role, offset range)`, in the order yielded
    // by `Assistant::messages`, for compact assertions.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.range))
            .collect()
    }
}