message_editor.rs

  1use editor::{Editor, EditorElement, EditorStyle};
  2use gpui::{AppContext, Model, TextStyle, View};
  3use language_model::{
  4    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
  5};
  6use settings::Settings;
  7use theme::ThemeSettings;
  8use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
  9
 10use crate::thread::{self, Thread};
 11use crate::Chat;
 12
 13#[derive(Debug, Clone, Copy)]
 14pub enum RequestKind {
 15    Chat,
 16}
 17
 18pub struct MessageEditor {
 19    thread: Model<Thread>,
 20    editor: View<Editor>,
 21}
 22
 23impl MessageEditor {
 24    pub fn new(thread: Model<Thread>, cx: &mut ViewContext<Self>) -> Self {
 25        Self {
 26            thread,
 27            editor: cx.new_view(|cx| {
 28                let mut editor = Editor::auto_height(80, cx);
 29                editor.set_placeholder_text("Ask anything…", cx);
 30
 31                editor
 32            }),
 33        }
 34    }
 35
 36    fn chat(&mut self, _: &Chat, cx: &mut ViewContext<Self>) {
 37        self.send_to_model(RequestKind::Chat, cx);
 38    }
 39
 40    fn send_to_model(
 41        &mut self,
 42        request_kind: RequestKind,
 43        cx: &mut ViewContext<Self>,
 44    ) -> Option<()> {
 45        let provider = LanguageModelRegistry::read_global(cx).active_provider();
 46        if provider
 47            .as_ref()
 48            .map_or(false, |provider| provider.must_accept_terms(cx))
 49        {
 50            cx.notify();
 51            return None;
 52        }
 53
 54        let model_registry = LanguageModelRegistry::read_global(cx);
 55        let model = model_registry.active_model()?;
 56
 57        let request = self.build_completion_request(request_kind, cx);
 58
 59        let user_message = self.editor.read(cx).text(cx);
 60        self.thread.update(cx, |thread, _cx| {
 61            thread.messages.push(thread::Message {
 62                role: Role::User,
 63                text: user_message,
 64            });
 65        });
 66
 67        self.editor.update(cx, |editor, cx| {
 68            editor.clear(cx);
 69        });
 70
 71        self.thread.update(cx, |thread, cx| {
 72            thread.stream_completion(request, model, cx)
 73        });
 74
 75        None
 76    }
 77
 78    fn build_completion_request(
 79        &self,
 80        _request_kind: RequestKind,
 81        cx: &AppContext,
 82    ) -> LanguageModelRequest {
 83        let text = self.editor.read(cx).text(cx);
 84
 85        let request = LanguageModelRequest {
 86            messages: vec![LanguageModelRequestMessage {
 87                role: Role::User,
 88                content: vec![MessageContent::Text(text)],
 89                cache: false,
 90            }],
 91            tools: Vec::new(),
 92            stop: Vec::new(),
 93            temperature: None,
 94        };
 95
 96        request
 97    }
 98}
 99
100impl Render for MessageEditor {
101    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
102        let font_size = TextSize::Default.rems(cx);
103        let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
104        let focus_handle = self.editor.focus_handle(cx);
105
106        v_flex()
107            .key_context("MessageEditor")
108            .on_action(cx.listener(Self::chat))
109            .size_full()
110            .gap_2()
111            .p_2()
112            .bg(cx.theme().colors().editor_background)
113            .child({
114                let settings = ThemeSettings::get_global(cx);
115                let text_style = TextStyle {
116                    color: cx.theme().colors().editor_foreground,
117                    font_family: settings.ui_font.family.clone(),
118                    font_features: settings.ui_font.features.clone(),
119                    font_size: font_size.into(),
120                    font_weight: settings.ui_font.weight,
121                    line_height: line_height.into(),
122                    ..Default::default()
123                };
124
125                EditorElement::new(
126                    &self.editor,
127                    EditorStyle {
128                        background: cx.theme().colors().editor_background,
129                        local_player: cx.theme().players().local(),
130                        text: text_style,
131                        ..Default::default()
132                    },
133                )
134            })
135            .child(
136                h_flex()
137                    .justify_between()
138                    .child(
139                        h_flex().child(
140                            Button::new("add-context", "Add Context")
141                                .style(ButtonStyle::Filled)
142                                .icon(IconName::Plus)
143                                .icon_position(IconPosition::Start),
144                        ),
145                    )
146                    .child(
147                        h_flex()
148                            .gap_2()
149                            .child(Button::new("codebase", "Codebase").style(ButtonStyle::Filled))
150                            .child(Label::new("or"))
151                            .child(
152                                ButtonLike::new("chat")
153                                    .style(ButtonStyle::Filled)
154                                    .layer(ElevationIndex::ModalSurface)
155                                    .child(Label::new("Chat"))
156                                    .children(
157                                        KeyBinding::for_action_in(&Chat, &focus_handle, cx)
158                                            .map(|binding| binding.into_any_element()),
159                                    )
160                                    .on_click(move |_event, cx| {
161                                        focus_handle.dispatch_action(&Chat, cx);
162                                    }),
163                            ),
164                    ),
165            )
166    }
167}