message_editor.rs

  1use editor::{Editor, EditorElement, EditorStyle};
  2use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
  3use gpui::{AppContext, FocusableView, Model, TextStyle, View};
  4use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
  5use settings::Settings;
  6use theme::ThemeSettings;
  7use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
  8
  9use crate::thread::{RequestKind, Thread};
 10use crate::Chat;
 11
 12pub struct MessageEditor {
 13    thread: Model<Thread>,
 14    editor: View<Editor>,
 15}
 16
 17impl MessageEditor {
 18    pub fn new(thread: Model<Thread>, cx: &mut ViewContext<Self>) -> Self {
 19        Self {
 20            thread,
 21            editor: cx.new_view(|cx| {
 22                let mut editor = Editor::auto_height(80, cx);
 23                editor.set_placeholder_text("Ask anything…", cx);
 24
 25                editor
 26            }),
 27        }
 28    }
 29
 30    fn chat(&mut self, _: &Chat, cx: &mut ViewContext<Self>) {
 31        self.send_to_model(RequestKind::Chat, cx);
 32    }
 33
 34    fn send_to_model(
 35        &mut self,
 36        request_kind: RequestKind,
 37        cx: &mut ViewContext<Self>,
 38    ) -> Option<()> {
 39        let provider = LanguageModelRegistry::read_global(cx).active_provider();
 40        if provider
 41            .as_ref()
 42            .map_or(false, |provider| provider.must_accept_terms(cx))
 43        {
 44            cx.notify();
 45            return None;
 46        }
 47
 48        let model_registry = LanguageModelRegistry::read_global(cx);
 49        let model = model_registry.active_model()?;
 50
 51        let user_message = self.editor.update(cx, |editor, cx| {
 52            let text = editor.text(cx);
 53            editor.clear(cx);
 54            text
 55        });
 56
 57        self.thread.update(cx, |thread, cx| {
 58            thread.insert_user_message(user_message);
 59            let mut request = thread.to_completion_request(request_kind, cx);
 60
 61            if cx.has_flag::<ToolUseFeatureFlag>() {
 62                request.tools = thread
 63                    .tools()
 64                    .tools(cx)
 65                    .into_iter()
 66                    .map(|tool| LanguageModelRequestTool {
 67                        name: tool.name(),
 68                        description: tool.description(),
 69                        input_schema: tool.input_schema(),
 70                    })
 71                    .collect();
 72            }
 73
 74            thread.stream_completion(request, model, cx)
 75        });
 76
 77        None
 78    }
 79}
 80
 81impl FocusableView for MessageEditor {
 82    fn focus_handle(&self, cx: &AppContext) -> gpui::FocusHandle {
 83        self.editor.focus_handle(cx)
 84    }
 85}
 86
 87impl Render for MessageEditor {
 88    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 89        let font_size = TextSize::Default.rems(cx);
 90        let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
 91        let focus_handle = self.editor.focus_handle(cx);
 92
 93        v_flex()
 94            .key_context("MessageEditor")
 95            .on_action(cx.listener(Self::chat))
 96            .size_full()
 97            .gap_2()
 98            .p_2()
 99            .bg(cx.theme().colors().editor_background)
100            .child({
101                let settings = ThemeSettings::get_global(cx);
102                let text_style = TextStyle {
103                    color: cx.theme().colors().editor_foreground,
104                    font_family: settings.ui_font.family.clone(),
105                    font_features: settings.ui_font.features.clone(),
106                    font_size: font_size.into(),
107                    font_weight: settings.ui_font.weight,
108                    line_height: line_height.into(),
109                    ..Default::default()
110                };
111
112                EditorElement::new(
113                    &self.editor,
114                    EditorStyle {
115                        background: cx.theme().colors().editor_background,
116                        local_player: cx.theme().players().local(),
117                        text: text_style,
118                        ..Default::default()
119                    },
120                )
121            })
122            .child(
123                h_flex()
124                    .justify_between()
125                    .child(
126                        h_flex().child(
127                            Button::new("add-context", "Add Context")
128                                .style(ButtonStyle::Filled)
129                                .icon(IconName::Plus)
130                                .icon_position(IconPosition::Start),
131                        ),
132                    )
133                    .child(
134                        h_flex()
135                            .gap_2()
136                            .child(Button::new("codebase", "Codebase").style(ButtonStyle::Filled))
137                            .child(Label::new("or"))
138                            .child(
139                                ButtonLike::new("chat")
140                                    .style(ButtonStyle::Filled)
141                                    .layer(ElevationIndex::ModalSurface)
142                                    .child(Label::new("Chat"))
143                                    .children(
144                                        KeyBinding::for_action_in(&Chat, &focus_handle, cx)
145                                            .map(|binding| binding.into_any_element()),
146                                    )
147                                    .on_click(move |_event, cx| {
148                                        focus_handle.dispatch_action(&Chat, cx);
149                                    }),
150                            ),
151                    ),
152            )
153    }
154}