1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
4};
5use anyhow::{anyhow, Result};
6use chrono::{DateTime, Local};
7use collections::{HashMap, HashSet};
8use editor::{
9 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
10 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
11 Anchor, Editor,
12};
13use fs::Fs;
14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
15use gpui::{
16 actions,
17 elements::*,
18 executor::Background,
19 geometry::vector::{vec2f, Vector2F},
20 platform::{CursorStyle, MouseButton},
21 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
22 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
23};
24use isahc::{http::StatusCode, Request, RequestExt};
25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
26use serde::Deserialize;
27use settings::SettingsStore;
28use std::{
29 borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
30 time::Duration,
31};
32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
33use workspace::{
34 dock::{DockPosition, Panel},
35 item::Item,
36 pane, Pane, Workspace,
37};
38
39const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
40
41actions!(
42 assistant,
43 [
44 NewContext,
45 Assist,
46 Split,
47 QuoteSelection,
48 ToggleFocus,
49 ResetKey
50 ]
51);
52
53pub fn init(cx: &mut AppContext) {
54 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
55 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
56 filter.filtered_namespaces.insert("assistant");
57 });
58 }
59
60 settings::register::<AssistantSettings>(cx);
61 cx.add_action(
62 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
63 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
64 this.update(cx, |this, cx| this.add_context(cx))
65 }
66
67 workspace.focus_panel::<AssistantPanel>(cx);
68 },
69 );
70 cx.add_action(AssistantEditor::assist);
71 cx.capture_action(AssistantEditor::cancel_last_assist);
72 cx.add_action(AssistantEditor::quote_selection);
73 cx.capture_action(AssistantEditor::copy);
74 cx.capture_action(AssistantEditor::split);
75 cx.add_action(AssistantPanel::save_api_key);
76 cx.add_action(AssistantPanel::reset_api_key);
77 cx.add_action(
78 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
79 workspace.toggle_panel_focus::<AssistantPanel>(cx);
80 },
81 );
82}
83
/// Events emitted by [`AssistantPanel`]. They are interpreted through the
/// `Panel` trait's `should_*_on_event` / `is_focus_event` hooks.
pub enum AssistantPanelEvent {
    /// The inner pane asked to be zoomed in.
    ZoomIn,
    /// The inner pane asked to be zoomed out.
    ZoomOut,
    /// The inner pane received focus.
    Focus,
    /// The inner pane asked to be removed; the panel should close.
    Close,
    /// The `dock` value in the assistant settings changed.
    DockPositionChanged,
}
91
/// Dockable panel that hosts assistant conversations as items of an inner pane.
pub struct AssistantPanel {
    /// Width set by the user while docked left/right; `None` falls back to the settings default.
    width: Option<f32>,
    /// Height set by the user while docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    /// Pane whose items are the individual conversation editors.
    pane: ViewHandle<Pane>,
    /// OpenAI API key, shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Single-line editor prompting for an API key; while `Some`, it is rendered
    /// instead of the pane.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Whether the platform credential store has already been queried for a saved key.
    has_read_credentials: bool,
    /// Language registry handed to each new conversation.
    languages: Arc<LanguageRegistry>,
    /// Filesystem used to write dock-position changes back to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions kept alive for as long as the panel exists.
    subscriptions: Vec<Subscription>,
}
103
104impl AssistantPanel {
105 pub fn load(
106 workspace: WeakViewHandle<Workspace>,
107 cx: AsyncAppContext,
108 ) -> Task<Result<ViewHandle<Self>>> {
109 cx.spawn(|mut cx| async move {
110 // TODO: deserialize state.
111 workspace.update(&mut cx, |workspace, cx| {
112 cx.add_view::<Self, _>(|cx| {
113 let weak_self = cx.weak_handle();
114 let pane = cx.add_view(|cx| {
115 let mut pane = Pane::new(
116 workspace.weak_handle(),
117 workspace.project().clone(),
118 workspace.app_state().background_actions,
119 Default::default(),
120 cx,
121 );
122 pane.set_can_split(false, cx);
123 pane.set_can_navigate(false, cx);
124 pane.on_can_drop(move |_, _| false);
125 pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
126 let weak_self = weak_self.clone();
127 Flex::row()
128 .with_child(Pane::render_tab_bar_button(
129 0,
130 "icons/plus_12.svg",
131 false,
132 Some(("New Context".into(), Some(Box::new(NewContext)))),
133 cx,
134 move |_, cx| {
135 let weak_self = weak_self.clone();
136 cx.window_context().defer(move |cx| {
137 if let Some(this) = weak_self.upgrade(cx) {
138 this.update(cx, |this, cx| this.add_context(cx));
139 }
140 })
141 },
142 None,
143 ))
144 .with_child(Pane::render_tab_bar_button(
145 1,
146 if pane.is_zoomed() {
147 "icons/minimize_8.svg"
148 } else {
149 "icons/maximize_8.svg"
150 },
151 pane.is_zoomed(),
152 Some((
153 "Toggle Zoom".into(),
154 Some(Box::new(workspace::ToggleZoom)),
155 )),
156 cx,
157 move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
158 None,
159 ))
160 .into_any()
161 });
162 let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
163 pane.toolbar()
164 .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
165 pane
166 });
167
168 let mut this = Self {
169 pane,
170 api_key: Rc::new(RefCell::new(None)),
171 api_key_editor: None,
172 has_read_credentials: false,
173 languages: workspace.app_state().languages.clone(),
174 fs: workspace.app_state().fs.clone(),
175 width: None,
176 height: None,
177 subscriptions: Default::default(),
178 };
179
180 let mut old_dock_position = this.position(cx);
181 this.subscriptions = vec![
182 cx.observe(&this.pane, |_, _, cx| cx.notify()),
183 cx.subscribe(&this.pane, Self::handle_pane_event),
184 cx.observe_global::<SettingsStore, _>(move |this, cx| {
185 let new_dock_position = this.position(cx);
186 if new_dock_position != old_dock_position {
187 old_dock_position = new_dock_position;
188 cx.emit(AssistantPanelEvent::DockPositionChanged);
189 }
190 }),
191 ];
192
193 this
194 })
195 })
196 })
197 }
198
199 fn handle_pane_event(
200 &mut self,
201 _pane: ViewHandle<Pane>,
202 event: &pane::Event,
203 cx: &mut ViewContext<Self>,
204 ) {
205 match event {
206 pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
207 pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
208 pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
209 pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
210 _ => {}
211 }
212 }
213
214 fn add_context(&mut self, cx: &mut ViewContext<Self>) {
215 let focus = self.has_focus(cx);
216 let editor = cx
217 .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
218 self.subscriptions
219 .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
220 self.pane.update(cx, |pane, cx| {
221 pane.add_item(Box::new(editor), true, focus, None, cx)
222 });
223 }
224
225 fn handle_assistant_editor_event(
226 &mut self,
227 _: ViewHandle<AssistantEditor>,
228 event: &AssistantEditorEvent,
229 cx: &mut ViewContext<Self>,
230 ) {
231 match event {
232 AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
233 }
234 }
235
236 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
237 if let Some(api_key) = self
238 .api_key_editor
239 .as_ref()
240 .map(|editor| editor.read(cx).text(cx))
241 {
242 if !api_key.is_empty() {
243 cx.platform()
244 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
245 .log_err();
246 *self.api_key.borrow_mut() = Some(api_key);
247 self.api_key_editor.take();
248 cx.focus_self();
249 cx.notify();
250 }
251 } else {
252 cx.propagate_action();
253 }
254 }
255
256 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
257 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
258 self.api_key.take();
259 self.api_key_editor = Some(build_api_key_editor(cx));
260 cx.focus_self();
261 cx.notify();
262 }
263}
264
265fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
266 cx.add_view(|cx| {
267 let mut editor = Editor::single_line(
268 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
269 cx,
270 );
271 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
272 editor
273 })
274}
275
/// The panel emits [`AssistantPanelEvent`]s to its subscribers.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
279
280impl View for AssistantPanel {
281 fn ui_name() -> &'static str {
282 "AssistantPanel"
283 }
284
285 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
286 let style = &theme::current(cx).assistant;
287 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
288 Flex::column()
289 .with_child(
290 Text::new(
291 "Paste your OpenAI API key and press Enter to use the assistant",
292 style.api_key_prompt.text.clone(),
293 )
294 .aligned(),
295 )
296 .with_child(
297 ChildView::new(api_key_editor, cx)
298 .contained()
299 .with_style(style.api_key_editor.container)
300 .aligned(),
301 )
302 .contained()
303 .with_style(style.api_key_prompt.container)
304 .aligned()
305 .into_any()
306 } else {
307 ChildView::new(&self.pane, cx).into_any()
308 }
309 }
310
311 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
312 if cx.is_self_focused() {
313 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
314 cx.focus(api_key_editor);
315 } else {
316 cx.focus(&self.pane);
317 }
318 }
319 }
320}
321
322impl Panel for AssistantPanel {
323 fn position(&self, cx: &WindowContext) -> DockPosition {
324 match settings::get::<AssistantSettings>(cx).dock {
325 AssistantDockPosition::Left => DockPosition::Left,
326 AssistantDockPosition::Bottom => DockPosition::Bottom,
327 AssistantDockPosition::Right => DockPosition::Right,
328 }
329 }
330
331 fn position_is_valid(&self, _: DockPosition) -> bool {
332 true
333 }
334
335 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
336 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
337 let dock = match position {
338 DockPosition::Left => AssistantDockPosition::Left,
339 DockPosition::Bottom => AssistantDockPosition::Bottom,
340 DockPosition::Right => AssistantDockPosition::Right,
341 };
342 settings.dock = Some(dock);
343 });
344 }
345
346 fn size(&self, cx: &WindowContext) -> f32 {
347 let settings = settings::get::<AssistantSettings>(cx);
348 match self.position(cx) {
349 DockPosition::Left | DockPosition::Right => {
350 self.width.unwrap_or_else(|| settings.default_width)
351 }
352 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
353 }
354 }
355
356 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
357 match self.position(cx) {
358 DockPosition::Left | DockPosition::Right => self.width = Some(size),
359 DockPosition::Bottom => self.height = Some(size),
360 }
361 cx.notify();
362 }
363
364 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
365 matches!(event, AssistantPanelEvent::ZoomIn)
366 }
367
368 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
369 matches!(event, AssistantPanelEvent::ZoomOut)
370 }
371
372 fn is_zoomed(&self, cx: &WindowContext) -> bool {
373 self.pane.read(cx).is_zoomed()
374 }
375
376 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
377 self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
378 }
379
380 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
381 if active {
382 if self.api_key.borrow().is_none() && !self.has_read_credentials {
383 self.has_read_credentials = true;
384 let api_key = if let Some((_, api_key)) = cx
385 .platform()
386 .read_credentials(OPENAI_API_URL)
387 .log_err()
388 .flatten()
389 {
390 String::from_utf8(api_key).log_err()
391 } else {
392 None
393 };
394 if let Some(api_key) = api_key {
395 *self.api_key.borrow_mut() = Some(api_key);
396 } else if self.api_key_editor.is_none() {
397 self.api_key_editor = Some(build_api_key_editor(cx));
398 cx.notify();
399 }
400 }
401
402 if self.pane.read(cx).items_len() == 0 {
403 self.add_context(cx);
404 }
405 }
406 }
407
408 fn icon_path(&self) -> &'static str {
409 "icons/robot_14.svg"
410 }
411
412 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
413 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
414 }
415
416 fn should_change_position_on_event(event: &Self::Event) -> bool {
417 matches!(event, AssistantPanelEvent::DockPositionChanged)
418 }
419
420 fn should_activate_on_event(_: &Self::Event) -> bool {
421 false
422 }
423
424 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
425 matches!(event, AssistantPanelEvent::Close)
426 }
427
428 fn has_focus(&self, cx: &WindowContext) -> bool {
429 self.pane.read(cx).has_focus()
430 || self
431 .api_key_editor
432 .as_ref()
433 .map_or(false, |editor| editor.is_focused(cx))
434 }
435
436 fn is_focus_event(event: &Self::Event) -> bool {
437 matches!(event, AssistantPanelEvent::Focus)
438 }
439}
440
/// Events emitted by an [`Assistant`] conversation model.
enum AssistantEvent {
    /// Messages changed: buffer edited, a message inserted or split, or a role cycled.
    MessagesEdited,
    /// More summary (title) text arrived.
    SummaryChanged,
    /// A chunk of a streamed completion was inserted into the buffer.
    StreamedCompletion,
}
446
/// A single conversation: message text lives in one buffer, and each message
/// is identified by an anchor at its start plus per-message metadata.
struct Assistant {
    /// Buffer holding the text of every message (Markdown language is assigned asynchronously).
    buffer: ModelHandle<Buffer>,
    /// Message start anchors, ordered as they appear in the buffer.
    messages: Vec<Message>,
    /// Role, send time, and last error for each message.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Counter used to mint new `MessageId`s.
    next_message_id: MessageId,
    /// Conversation title streamed back from the model, once requested.
    summary: Option<String>,
    /// In-flight summary request.
    pending_summary: Task<Option<()>>,
    /// Counter used to mint ids for `pending_completions`.
    completion_count: usize,
    /// In-flight completion requests; the last one is cancelled by `cancel_last_assist`.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name sent with each request.
    model: String,
    /// Token count of the whole conversation; `None` until the first count finishes.
    token_count: Option<usize>,
    /// Context-window size of `model`, per `tiktoken_rs::model::get_context_size`.
    max_token_count: usize,
    /// Debounced background token count.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel; requests are skipped while it is `None`.
    api_key: Rc<RefCell<Option<String>>>,
    _subscriptions: Vec<Subscription>,
}
463
/// The conversation model emits [`AssistantEvent`]s to its subscribers.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
467
468impl Assistant {
469 fn new(
470 api_key: Rc<RefCell<Option<String>>>,
471 language_registry: Arc<LanguageRegistry>,
472 cx: &mut ModelContext<Self>,
473 ) -> Self {
474 let model = "gpt-3.5-turbo";
475 let markdown = language_registry.language_for_name("Markdown");
476 let buffer = cx.add_model(|cx| {
477 let mut buffer = Buffer::new(0, "", cx);
478 buffer.set_language_registry(language_registry);
479 cx.spawn_weak(|buffer, mut cx| async move {
480 let markdown = markdown.await?;
481 let buffer = buffer
482 .upgrade(&cx)
483 .ok_or_else(|| anyhow!("buffer was dropped"))?;
484 buffer.update(&mut cx, |buffer, cx| {
485 buffer.set_language(Some(markdown), cx)
486 });
487 anyhow::Ok(())
488 })
489 .detach_and_log_err(cx);
490 buffer
491 });
492
493 let mut this = Self {
494 messages: Default::default(),
495 messages_metadata: Default::default(),
496 next_message_id: Default::default(),
497 summary: None,
498 pending_summary: Task::ready(None),
499 completion_count: Default::default(),
500 pending_completions: Default::default(),
501 token_count: None,
502 max_token_count: tiktoken_rs::model::get_context_size(model),
503 pending_token_count: Task::ready(None),
504 model: model.into(),
505 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
506 api_key,
507 buffer,
508 };
509 let message = Message {
510 id: MessageId(post_inc(&mut this.next_message_id.0)),
511 start: language::Anchor::MIN,
512 };
513 this.messages.push(message.clone());
514 this.messages_metadata.insert(
515 message.id,
516 MessageMetadata {
517 role: Role::User,
518 sent_at: Local::now(),
519 error: None,
520 },
521 );
522
523 this.count_remaining_tokens(cx);
524 this
525 }
526
527 fn handle_buffer_event(
528 &mut self,
529 _: ModelHandle<Buffer>,
530 event: &language::Event,
531 cx: &mut ModelContext<Self>,
532 ) {
533 match event {
534 language::Event::Edited => {
535 self.count_remaining_tokens(cx);
536 cx.emit(AssistantEvent::MessagesEdited);
537 }
538 _ => {}
539 }
540 }
541
542 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
543 let messages = self
544 .open_ai_request_messages(cx)
545 .into_iter()
546 .filter_map(|message| {
547 Some(tiktoken_rs::ChatCompletionRequestMessage {
548 role: match message.role {
549 Role::User => "user".into(),
550 Role::Assistant => "assistant".into(),
551 Role::System => "system".into(),
552 },
553 content: message.content,
554 name: None,
555 })
556 })
557 .collect::<Vec<_>>();
558 let model = self.model.clone();
559 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
560 async move {
561 cx.background().timer(Duration::from_millis(200)).await;
562 let token_count = cx
563 .background()
564 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
565 .await?;
566
567 this.upgrade(&cx)
568 .ok_or_else(|| anyhow!("assistant was dropped"))?
569 .update(&mut cx, |this, cx| {
570 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
571 this.token_count = Some(token_count);
572 cx.notify()
573 });
574 anyhow::Ok(())
575 }
576 .log_err()
577 });
578 }
579
580 fn remaining_tokens(&self) -> Option<isize> {
581 Some(self.max_token_count as isize - self.token_count? as isize)
582 }
583
584 fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
585 self.model = model;
586 self.count_remaining_tokens(cx);
587 cx.notify();
588 }
589
590 fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
591 let request = OpenAIRequest {
592 model: self.model.clone(),
593 messages: self.open_ai_request_messages(cx),
594 stream: true,
595 };
596
597 let api_key = self.api_key.borrow().clone()?;
598 let stream = stream_completion(api_key, cx.background().clone(), request);
599 let assistant_message =
600 self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
601 let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
602 let task = cx.spawn_weak({
603 |this, mut cx| async move {
604 let assistant_message_id = assistant_message.id;
605 let stream_completion = async {
606 let mut messages = stream.await?;
607
608 while let Some(message) = messages.next().await {
609 let mut message = message?;
610 if let Some(choice) = message.choices.pop() {
611 this.upgrade(&cx)
612 .ok_or_else(|| anyhow!("assistant was dropped"))?
613 .update(&mut cx, |this, cx| {
614 let text: Arc<str> = choice.delta.content?.into();
615 let message_ix = this
616 .messages
617 .iter()
618 .position(|message| message.id == assistant_message_id)?;
619 this.buffer.update(cx, |buffer, cx| {
620 let offset = if message_ix + 1 == this.messages.len() {
621 buffer.len()
622 } else {
623 this.messages[message_ix + 1]
624 .start
625 .to_offset(buffer)
626 .saturating_sub(1)
627 };
628 buffer.edit([(offset..offset, text)], None, cx);
629 });
630 cx.emit(AssistantEvent::StreamedCompletion);
631
632 Some(())
633 });
634 }
635 }
636
637 this.upgrade(&cx)
638 .ok_or_else(|| anyhow!("assistant was dropped"))?
639 .update(&mut cx, |this, cx| {
640 this.pending_completions
641 .retain(|completion| completion.id != this.completion_count);
642 this.summarize(cx);
643 });
644
645 anyhow::Ok(())
646 };
647
648 let result = stream_completion.await;
649 if let Some(this) = this.upgrade(&cx) {
650 this.update(&mut cx, |this, cx| {
651 if let Err(error) = result {
652 if let Some(metadata) =
653 this.messages_metadata.get_mut(&assistant_message.id)
654 {
655 metadata.error = Some(error.to_string().trim().into());
656 cx.notify();
657 }
658 }
659 });
660 }
661 }
662 });
663
664 self.pending_completions.push(PendingCompletion {
665 id: post_inc(&mut self.completion_count),
666 _task: task,
667 });
668 Some((assistant_message, user_message))
669 }
670
671 fn cancel_last_assist(&mut self) -> bool {
672 self.pending_completions.pop().is_some()
673 }
674
675 fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
676 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
677 metadata.role.cycle();
678 cx.emit(AssistantEvent::MessagesEdited);
679 cx.notify();
680 }
681 }
682
683 fn insert_message_after(
684 &mut self,
685 message_id: MessageId,
686 role: Role,
687 cx: &mut ModelContext<Self>,
688 ) -> Option<Message> {
689 if let Some(prev_message_ix) = self
690 .messages
691 .iter()
692 .position(|message| message.id == message_id)
693 {
694 let start = self.buffer.update(cx, |buffer, cx| {
695 let offset = self.messages[prev_message_ix + 1..]
696 .iter()
697 .find(|message| message.start.is_valid(buffer))
698 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
699 buffer.edit([(offset..offset, "\n")], None, cx);
700 buffer.anchor_before(offset + 1)
701 });
702 let message = Message {
703 id: MessageId(post_inc(&mut self.next_message_id.0)),
704 start,
705 };
706 self.messages.insert(prev_message_ix + 1, message.clone());
707 self.messages_metadata.insert(
708 message.id,
709 MessageMetadata {
710 role,
711 sent_at: Local::now(),
712 error: None,
713 },
714 );
715 cx.emit(AssistantEvent::MessagesEdited);
716 Some(message)
717 } else {
718 None
719 }
720 }
721
722 fn split_message(
723 &mut self,
724 range: Range<usize>,
725 cx: &mut ModelContext<Self>,
726 ) -> (Option<Message>, Option<Message>) {
727 let start_message = self.message_for_offset(range.start, cx);
728 let end_message = self.message_for_offset(range.end, cx);
729 if let Some((start_message, end_message)) = start_message.zip(end_message) {
730 let (start_message_ix, _, metadata, message_range) = start_message;
731 let (end_message_ix, _, _, _) = end_message;
732
733 // Prevent splitting when range spans multiple messages.
734 if start_message_ix != end_message_ix {
735 return (None, None);
736 }
737
738 let role = metadata.role;
739 let mut edited_buffer = false;
740
741 let mut suffix_start = None;
742 if range.start > message_range.start && range.end < message_range.end - 1 {
743 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
744 suffix_start = Some(range.end + 1);
745 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
746 suffix_start = Some(range.end);
747 }
748 }
749
750 let suffix = if let Some(suffix_start) = suffix_start {
751 Message {
752 id: MessageId(post_inc(&mut self.next_message_id.0)),
753 start: self.buffer.read(cx).anchor_before(suffix_start),
754 }
755 } else {
756 self.buffer.update(cx, |buffer, cx| {
757 buffer.edit([(range.end..range.end, "\n")], None, cx);
758 });
759 edited_buffer = true;
760 Message {
761 id: MessageId(post_inc(&mut self.next_message_id.0)),
762 start: self.buffer.read(cx).anchor_before(range.end + 1),
763 }
764 };
765
766 self.messages.insert(start_message_ix + 1, suffix.clone());
767 self.messages_metadata.insert(
768 suffix.id,
769 MessageMetadata {
770 role,
771 sent_at: Local::now(),
772 error: None,
773 },
774 );
775
776 let new_messages = if range.start == range.end || range.start == message_range.start {
777 (None, Some(suffix))
778 } else {
779 let mut prefix_end = None;
780 if range.start > message_range.start && range.end < message_range.end - 1 {
781 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
782 prefix_end = Some(range.start + 1);
783 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
784 == Some('\n')
785 {
786 prefix_end = Some(range.start);
787 }
788 }
789
790 let selection = if let Some(prefix_end) = prefix_end {
791 cx.emit(AssistantEvent::MessagesEdited);
792 Message {
793 id: MessageId(post_inc(&mut self.next_message_id.0)),
794 start: self.buffer.read(cx).anchor_before(prefix_end),
795 }
796 } else {
797 self.buffer.update(cx, |buffer, cx| {
798 buffer.edit([(range.start..range.start, "\n")], None, cx)
799 });
800 edited_buffer = true;
801 Message {
802 id: MessageId(post_inc(&mut self.next_message_id.0)),
803 start: self.buffer.read(cx).anchor_before(range.end + 1),
804 }
805 };
806
807 self.messages
808 .insert(start_message_ix + 1, selection.clone());
809 self.messages_metadata.insert(
810 selection.id,
811 MessageMetadata {
812 role,
813 sent_at: Local::now(),
814 error: None,
815 },
816 );
817 (Some(selection), Some(suffix))
818 };
819
820 if !edited_buffer {
821 cx.emit(AssistantEvent::MessagesEdited);
822 }
823 new_messages
824 } else {
825 (None, None)
826 }
827 }
828
829 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
830 if self.messages.len() >= 2 && self.summary.is_none() {
831 let api_key = self.api_key.borrow().clone();
832 if let Some(api_key) = api_key {
833 let mut messages = self.open_ai_request_messages(cx);
834 messages.truncate(2);
835 messages.push(RequestMessage {
836 role: Role::User,
837 content: "Summarize the conversation into a short title without punctuation"
838 .into(),
839 });
840 let request = OpenAIRequest {
841 model: self.model.clone(),
842 messages,
843 stream: true,
844 };
845
846 let stream = stream_completion(api_key, cx.background().clone(), request);
847 self.pending_summary = cx.spawn(|this, mut cx| {
848 async move {
849 let mut messages = stream.await?;
850
851 while let Some(message) = messages.next().await {
852 let mut message = message?;
853 if let Some(choice) = message.choices.pop() {
854 let text = choice.delta.content.unwrap_or_default();
855 this.update(&mut cx, |this, cx| {
856 this.summary.get_or_insert(String::new()).push_str(&text);
857 cx.emit(AssistantEvent::SummaryChanged);
858 });
859 }
860 }
861
862 anyhow::Ok(())
863 }
864 .log_err()
865 });
866 }
867 }
868 }
869
870 fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
871 let buffer = self.buffer.read(cx);
872 self.messages(cx)
873 .map(|(_ix, _message, metadata, range)| RequestMessage {
874 role: metadata.role,
875 content: buffer.text_for_range(range).collect(),
876 })
877 .collect()
878 }
879
880 fn message_for_offset<'a>(
881 &'a self,
882 offset: usize,
883 cx: &'a AppContext,
884 ) -> Option<(usize, &Message, &MessageMetadata, Range<usize>)> {
885 let mut messages = self.messages(cx).peekable();
886 while let Some((ix, message, metadata, range)) = messages.next() {
887 if range.contains(&offset) || messages.peek().is_none() {
888 return Some((ix, message, metadata, range));
889 }
890 }
891 None
892 }
893
894 fn messages<'a>(
895 &'a self,
896 cx: &'a AppContext,
897 ) -> impl 'a + Iterator<Item = (usize, &Message, &MessageMetadata, Range<usize>)> {
898 let buffer = self.buffer.read(cx);
899 let mut messages = self.messages.iter().enumerate().peekable();
900 iter::from_fn(move || {
901 while let Some((ix, message)) = messages.next() {
902 let metadata = self.messages_metadata.get(&message.id)?;
903 let message_start = message.start.to_offset(buffer);
904 let mut message_end = None;
905 while let Some((_, next_message)) = messages.peek() {
906 if next_message.start.is_valid(buffer) {
907 message_end = Some(next_message.start);
908 break;
909 } else {
910 messages.next();
911 }
912 }
913 let message_end = message_end
914 .unwrap_or(language::Anchor::MAX)
915 .to_offset(buffer);
916 return Some((ix, message, metadata, message_start..message_end));
917 }
918 None
919 })
920 }
921}
922
/// An in-flight completion request owned by an [`Assistant`].
struct PendingCompletion {
    /// Id taken from `Assistant::completion_count` when the request started.
    id: usize,
    /// Streaming task. It is held only to keep the task alive; dropping it (as
    /// `cancel_last_assist` does by popping the entry) is how a completion is cancelled.
    _task: Task<()>,
}
927
/// Events emitted by an [`AssistantEditor`] to the panel hosting it.
enum AssistantEditorEvent {
    /// The tab's displayed content (e.g. the conversation summary) changed.
    TabContentChanged,
}
931
/// Where the newest cursor sits within the viewport. Used to keep the cursor
/// at the same on-screen row while a completion streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// x: horizontal scroll position; y: display rows between the viewport top and the cursor row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
937
/// Pane item that shows one [`Assistant`] conversation in an editor, with a
/// header block above each message.
struct AssistantEditor {
    /// The conversation model shown by this editor.
    assistant: ModelHandle<Assistant>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to preserve while streaming; `None` when not following the cursor.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
945
946impl AssistantEditor {
947 fn new(
948 api_key: Rc<RefCell<Option<String>>>,
949 language_registry: Arc<LanguageRegistry>,
950 cx: &mut ViewContext<Self>,
951 ) -> Self {
952 let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
953 let editor = cx.add_view(|cx| {
954 let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
955 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
956 editor.set_show_gutter(false, cx);
957 editor
958 });
959
960 let _subscriptions = vec![
961 cx.observe(&assistant, |_, _, cx| cx.notify()),
962 cx.subscribe(&assistant, Self::handle_assistant_event),
963 cx.subscribe(&editor, Self::handle_editor_event),
964 ];
965
966 let mut this = Self {
967 assistant,
968 editor,
969 blocks: Default::default(),
970 scroll_position: None,
971 _subscriptions,
972 };
973 this.update_message_headers(cx);
974 this
975 }
976
977 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
978 let user_message = self.assistant.update(cx, |assistant, cx| {
979 let (_, user_message) = assistant.assist(cx)?;
980 Some(user_message)
981 });
982
983 if let Some(user_message) = user_message {
984 let cursor = user_message
985 .start
986 .to_offset(&self.assistant.read(cx).buffer.read(cx));
987 self.editor.update(cx, |editor, cx| {
988 editor.change_selections(
989 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
990 cx,
991 |selections| selections.select_ranges([cursor..cursor]),
992 );
993 });
994 }
995 }
996
997 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
998 if !self
999 .assistant
1000 .update(cx, |assistant, _| assistant.cancel_last_assist())
1001 {
1002 cx.propagate_action();
1003 }
1004 }
1005
1006 fn handle_assistant_event(
1007 &mut self,
1008 _: ModelHandle<Assistant>,
1009 event: &AssistantEvent,
1010 cx: &mut ViewContext<Self>,
1011 ) {
1012 match event {
1013 AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1014 AssistantEvent::SummaryChanged => {
1015 cx.emit(AssistantEditorEvent::TabContentChanged);
1016 }
1017 AssistantEvent::StreamedCompletion => {
1018 self.editor.update(cx, |editor, cx| {
1019 if let Some(scroll_position) = self.scroll_position {
1020 let snapshot = editor.snapshot(cx);
1021 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1022 let scroll_top =
1023 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1024 editor.set_scroll_position(
1025 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1026 cx,
1027 );
1028 }
1029 });
1030 }
1031 }
1032 }
1033
1034 fn handle_editor_event(
1035 &mut self,
1036 _: ViewHandle<Editor>,
1037 event: &editor::Event,
1038 cx: &mut ViewContext<Self>,
1039 ) {
1040 match event {
1041 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1042 let cursor_scroll_position = self.cursor_scroll_position(cx);
1043 if *autoscroll {
1044 self.scroll_position = cursor_scroll_position;
1045 } else if self.scroll_position != cursor_scroll_position {
1046 self.scroll_position = None;
1047 }
1048 }
1049 editor::Event::SelectionsChanged { .. } => {
1050 self.scroll_position = self.cursor_scroll_position(cx);
1051 }
1052 _ => {}
1053 }
1054 }
1055
1056 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1057 self.editor.update(cx, |editor, cx| {
1058 let snapshot = editor.snapshot(cx);
1059 let cursor = editor.selections.newest_anchor().head();
1060 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1061 let scroll_position = editor
1062 .scroll_manager
1063 .anchor()
1064 .scroll_position(&snapshot.display_snapshot);
1065
1066 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1067 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1068 Some(ScrollPosition {
1069 cursor,
1070 offset_before_cursor: vec2f(
1071 scroll_position.x(),
1072 cursor_row - scroll_position.y(),
1073 ),
1074 })
1075 } else {
1076 None
1077 }
1078 })
1079 }
1080
    /// Rebuilds the header block rendered above every message.
    ///
    /// Each header shows the sender role ("You"/"Assistant"/"System"), the
    /// time the message was created, and an error icon with a tooltip when
    /// the message has an error. All previously inserted blocks are removed
    /// and one fresh 2-line sticky block per message is inserted. The new
    /// block ids are stored in `self.blocks`.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // Assumes the multibuffer is a singleton (the editor is built with
            // `Editor::for_buffer`); panics otherwise.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .assistant
                .read(cx)
                .messages(cx)
                .map(|(_, message, metadata, _)| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.start),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let assistant = self.assistant.clone();
                        let metadata = metadata.clone();
                        let message = message.clone();
                        move |cx| {
                            // Type tags distinguishing this mouse handler and tooltip.
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            // Clickable sender label styled per role; a left
                            // mouse-down cycles the message's role.
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match metadata.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state, false);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state, false);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state, false);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let assistant = assistant.clone();
                                move |_, _, cx| {
                                    assistant.update(cx, |assistant, cx| {
                                        assistant.cycle_message_role(message_id, cx)
                                    })
                                }
                            });

                            // Row: sender, 12-hour send time, and optional error icon.
                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        metadata.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(metadata.error.clone().map(|error| {
                                    Svg::new("icons/circle_x_mark_12.svg")
                                        .with_color(style.error_icon.color)
                                        .constrained()
                                        .with_width(style.error_icon.width)
                                        .contained()
                                        .with_style(style.error_icon.container)
                                        .with_tooltip::<ErrorTooltip>(
                                            message_id.0,
                                            error,
                                            None,
                                            theme.tooltip.clone(),
                                            cx,
                                        )
                                        .aligned()
                                }))
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1182
1183 fn quote_selection(
1184 workspace: &mut Workspace,
1185 _: &QuoteSelection,
1186 cx: &mut ViewContext<Workspace>,
1187 ) {
1188 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1189 return;
1190 };
1191 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1192 return;
1193 };
1194
1195 let text = editor.read_with(cx, |editor, cx| {
1196 let range = editor.selections.newest::<usize>(cx).range();
1197 let buffer = editor.buffer().read(cx).snapshot(cx);
1198 let start_language = buffer.language_at(range.start);
1199 let end_language = buffer.language_at(range.end);
1200 let language_name = if start_language == end_language {
1201 start_language.map(|language| language.name())
1202 } else {
1203 None
1204 };
1205 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1206
1207 let selected_text = buffer.text_for_range(range).collect::<String>();
1208 if selected_text.is_empty() {
1209 None
1210 } else {
1211 Some(if language_name == "markdown" {
1212 selected_text
1213 .lines()
1214 .map(|line| format!("> {}", line))
1215 .collect::<Vec<_>>()
1216 .join("\n")
1217 } else {
1218 format!("```{language_name}\n{selected_text}\n```")
1219 })
1220 }
1221 });
1222
1223 // Activate the panel
1224 if !panel.read(cx).has_focus(cx) {
1225 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1226 }
1227
1228 if let Some(text) = text {
1229 panel.update(cx, |panel, cx| {
1230 if let Some(assistant) = panel
1231 .pane
1232 .read(cx)
1233 .active_item()
1234 .and_then(|item| item.downcast::<AssistantEditor>())
1235 .ok_or_else(|| anyhow!("no active context"))
1236 .log_err()
1237 {
1238 assistant.update(cx, |assistant, cx| {
1239 assistant
1240 .editor
1241 .update(cx, |editor, cx| editor.insert(&text, cx))
1242 });
1243 }
1244 });
1245 }
1246 }
1247
1248 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1249 let editor = self.editor.read(cx);
1250 let assistant = self.assistant.read(cx);
1251 if editor.selections.count() == 1 {
1252 let selection = editor.selections.newest::<usize>(cx);
1253 let mut copied_text = String::new();
1254 let mut spanned_messages = 0;
1255 for (_ix, _message, metadata, message_range) in assistant.messages(cx) {
1256 if message_range.start >= selection.range().end {
1257 break;
1258 } else if message_range.end >= selection.range().start {
1259 let range = cmp::max(message_range.start, selection.range().start)
1260 ..cmp::min(message_range.end, selection.range().end);
1261 if !range.is_empty() {
1262 spanned_messages += 1;
1263 write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
1264 for chunk in assistant.buffer.read(cx).text_for_range(range) {
1265 copied_text.push_str(&chunk);
1266 }
1267 copied_text.push('\n');
1268 }
1269 }
1270 }
1271
1272 if spanned_messages > 1 {
1273 cx.platform()
1274 .write_to_clipboard(ClipboardItem::new(copied_text));
1275 return;
1276 }
1277 }
1278
1279 cx.propagate_action();
1280 }
1281
1282 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1283 self.assistant.update(cx, |assistant, cx| {
1284 let range = self.editor.read(cx).selections.newest::<usize>(cx).range();
1285 assistant.split_message(range, cx);
1286 });
1287 }
1288
1289 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1290 self.assistant.update(cx, |assistant, cx| {
1291 let new_model = match assistant.model.as_str() {
1292 "gpt-4" => "gpt-3.5-turbo",
1293 _ => "gpt-4",
1294 };
1295 assistant.set_model(new_model.into(), cx);
1296 });
1297 }
1298
1299 fn title(&self, cx: &AppContext) -> String {
1300 self.assistant
1301 .read(cx)
1302 .summary
1303 .clone()
1304 .unwrap_or_else(|| "New Context".into())
1305 }
1306}
1307
/// The event type emitted by an `AssistantEditor` is `AssistantEditorEvent`.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1311
1312impl View for AssistantEditor {
1313 fn ui_name() -> &'static str {
1314 "AssistantEditor"
1315 }
1316
1317 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1318 enum Model {}
1319 let theme = &theme::current(cx).assistant;
1320 let assistant = &self.assistant.read(cx);
1321 let model = assistant.model.clone();
1322 let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1323 let remaining_tokens_style = if remaining_tokens <= 0 {
1324 &theme.no_remaining_tokens
1325 } else {
1326 &theme.remaining_tokens
1327 };
1328 Label::new(
1329 remaining_tokens.to_string(),
1330 remaining_tokens_style.text.clone(),
1331 )
1332 .contained()
1333 .with_style(remaining_tokens_style.container)
1334 });
1335
1336 Stack::new()
1337 .with_child(
1338 ChildView::new(&self.editor, cx)
1339 .contained()
1340 .with_style(theme.container),
1341 )
1342 .with_child(
1343 Flex::row()
1344 .with_child(
1345 MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1346 let style = theme.model.style_for(state, false);
1347 Label::new(model, style.text.clone())
1348 .contained()
1349 .with_style(style.container)
1350 })
1351 .with_cursor_style(CursorStyle::PointingHand)
1352 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1353 )
1354 .with_children(remaining_tokens)
1355 .contained()
1356 .with_style(theme.model_info_container)
1357 .aligned()
1358 .top()
1359 .right(),
1360 )
1361 .into_any()
1362 }
1363
1364 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1365 if cx.is_self_focused() {
1366 cx.focus(&self.editor);
1367 }
1368 }
1369}
1370
1371impl Item for AssistantEditor {
1372 fn tab_content<V: View>(
1373 &self,
1374 _: Option<usize>,
1375 style: &theme::Tab,
1376 cx: &gpui::AppContext,
1377 ) -> AnyElement<V> {
1378 let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1379 Label::new(title, style.label.clone()).into_any()
1380 }
1381
1382 fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1383 Some(self.title(cx).into())
1384 }
1385
1386 fn as_searchable(
1387 &self,
1388 _: &ViewHandle<Self>,
1389 ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1390 Some(Box::new(self.editor.clone()))
1391 }
1392}
1393
/// Identifier of a message in an assistant conversation.
///
/// The wrapped `usize` also serves as the element id of the message header's
/// sender mouse handler and error tooltip.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1396
/// A message's identity and the position where it begins in the conversation buffer.
#[derive(Clone, Debug)]
struct Message {
    /// Stable identifier of the message.
    id: MessageId,
    /// Buffer anchor at the start of the message's text; the message header
    /// block is positioned here.
    start: language::Anchor,
}
1402
/// Per-message information shown in the message header.
#[derive(Clone, Debug)]
struct MessageMetadata {
    /// Author of the message; cycled by pressing on the header's sender label.
    role: Role,
    /// When the message was sent; rendered in the header as `%I:%M%P`.
    sent_at: DateTime<Local>,
    /// Error text, if any; shown as a tooltip on an error icon in the header.
    error: Option<String>,
}
1409
1410async fn stream_completion(
1411 api_key: String,
1412 executor: Arc<Background>,
1413 mut request: OpenAIRequest,
1414) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1415 request.stream = true;
1416
1417 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1418
1419 let json_data = serde_json::to_string(&request)?;
1420 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1421 .header("Content-Type", "application/json")
1422 .header("Authorization", format!("Bearer {}", api_key))
1423 .body(json_data)?
1424 .send_async()
1425 .await?;
1426
1427 let status = response.status();
1428 if status == StatusCode::OK {
1429 executor
1430 .spawn(async move {
1431 let mut lines = BufReader::new(response.body_mut()).lines();
1432
1433 fn parse_line(
1434 line: Result<String, io::Error>,
1435 ) -> Result<Option<OpenAIResponseStreamEvent>> {
1436 if let Some(data) = line?.strip_prefix("data: ") {
1437 let event = serde_json::from_str(&data)?;
1438 Ok(Some(event))
1439 } else {
1440 Ok(None)
1441 }
1442 }
1443
1444 while let Some(line) = lines.next().await {
1445 if let Some(event) = parse_line(line).transpose() {
1446 let done = event.as_ref().map_or(false, |event| {
1447 event
1448 .choices
1449 .last()
1450 .map_or(false, |choice| choice.finish_reason.is_some())
1451 });
1452 if tx.unbounded_send(event).is_err() {
1453 break;
1454 }
1455
1456 if done {
1457 break;
1458 }
1459 }
1460 }
1461
1462 anyhow::Ok(())
1463 })
1464 .detach();
1465
1466 Ok(rx)
1467 } else {
1468 let mut body = String::new();
1469 response.body_mut().read_to_string(&mut body).await?;
1470
1471 #[derive(Deserialize)]
1472 struct OpenAIResponse {
1473 error: OpenAIError,
1474 }
1475
1476 #[derive(Deserialize)]
1477 struct OpenAIError {
1478 message: String,
1479 }
1480
1481 match serde_json::from_str::<OpenAIResponse>(&body) {
1482 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1483 "Failed to connect to OpenAI API: {}",
1484 response.error.message,
1485 )),
1486
1487 _ => Err(anyhow!(
1488 "Failed to connect to OpenAI API: {} {}",
1489 response.status(),
1490 body,
1491 )),
1492 }
1493 }
1494}
1495
// Unit tests for how `Assistant` tracks message boundaries while its buffer is
// edited, undone/redone, and split.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    // Inserting messages, editing across them, merging by deletion, and
    // restoring the merged messages with undo/redo.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A new assistant starts with a single, empty user message.
        let message_1 = assistant.read(cx).messages[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 again places the new message between 2 and 3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // Splitting messages at cursor positions and at a non-empty range, and
    // checking how existing newlines are reused or new ones inserted.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).messages[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range produces a message for the selected text
        // and another for the text after it.
        let (message_6, message_7) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // Flattens the assistant's messages into (id, role, offset range) tuples
    // for compact assertions.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|(_, message, metadata, range)| (message.id, metadata.role, range))
            .collect()
    }
}