message_editor.rs

  1use editor::{Editor, EditorElement, EditorStyle};
  2use futures::StreamExt;
  3use gpui::{AppContext, Model, TextStyle, View};
  4use language_model::{
  5    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  6    LanguageModelRequestMessage, MessageContent, Role, StopReason,
  7};
  8use settings::Settings;
  9use theme::ThemeSettings;
 10use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
 11use util::ResultExt;
 12
 13use crate::thread::{self, Thread};
 14use crate::Chat;
 15
 16#[derive(Debug, Clone, Copy)]
 17pub enum RequestKind {
 18    Chat,
 19}
 20
 21pub struct MessageEditor {
 22    thread: Model<Thread>,
 23    editor: View<Editor>,
 24}
 25
 26impl MessageEditor {
 27    pub fn new(thread: Model<Thread>, cx: &mut ViewContext<Self>) -> Self {
 28        Self {
 29            thread,
 30            editor: cx.new_view(|cx| {
 31                let mut editor = Editor::auto_height(80, cx);
 32                editor.set_placeholder_text("Ask anything…", cx);
 33
 34                editor
 35            }),
 36        }
 37    }
 38
 39    fn chat(&mut self, _: &Chat, cx: &mut ViewContext<Self>) {
 40        self.send_to_model(RequestKind::Chat, cx);
 41    }
 42
 43    fn send_to_model(
 44        &mut self,
 45        request_kind: RequestKind,
 46        cx: &mut ViewContext<Self>,
 47    ) -> Option<()> {
 48        let provider = LanguageModelRegistry::read_global(cx).active_provider();
 49        if provider
 50            .as_ref()
 51            .map_or(false, |provider| provider.must_accept_terms(cx))
 52        {
 53            cx.notify();
 54            return None;
 55        }
 56
 57        let model_registry = LanguageModelRegistry::read_global(cx);
 58        let model = model_registry.active_model()?;
 59
 60        let request = self.build_completion_request(request_kind, cx);
 61
 62        let user_message = self.editor.read(cx).text(cx);
 63        self.thread.update(cx, |thread, _cx| {
 64            thread.messages.push(thread::Message {
 65                role: Role::User,
 66                text: user_message,
 67            });
 68        });
 69
 70        self.editor.update(cx, |editor, cx| {
 71            editor.clear(cx);
 72        });
 73
 74        let task = cx.spawn(|this, mut cx| async move {
 75            let stream = model.stream_completion(request, &cx);
 76            let stream_completion = async {
 77                let mut events = stream.await?;
 78                let mut stop_reason = StopReason::EndTurn;
 79
 80                let mut text = String::new();
 81
 82                while let Some(event) = events.next().await {
 83                    let event = event?;
 84                    match event {
 85                        LanguageModelCompletionEvent::StartMessage { .. } => {}
 86                        LanguageModelCompletionEvent::Stop(reason) => {
 87                            stop_reason = reason;
 88                        }
 89                        LanguageModelCompletionEvent::Text(chunk) => {
 90                            text.push_str(&chunk);
 91                        }
 92                        LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
 93                    }
 94
 95                    smol::future::yield_now().await;
 96                }
 97
 98                anyhow::Ok((stop_reason, text))
 99            };
100
101            let result = stream_completion.await;
102
103            this.update(&mut cx, |this, cx| {
104                if let Some((_stop_reason, text)) = result.log_err() {
105                    this.thread.update(cx, |thread, _cx| {
106                        thread.messages.push(thread::Message {
107                            role: Role::Assistant,
108                            text,
109                        });
110                    });
111                }
112            })
113            .ok();
114        });
115
116        self.thread.update(cx, |thread, _cx| {
117            thread.pending_completion_tasks.push(task);
118        });
119
120        None
121    }
122
123    fn build_completion_request(
124        &self,
125        _request_kind: RequestKind,
126        cx: &AppContext,
127    ) -> LanguageModelRequest {
128        let text = self.editor.read(cx).text(cx);
129
130        let request = LanguageModelRequest {
131            messages: vec![LanguageModelRequestMessage {
132                role: Role::User,
133                content: vec![MessageContent::Text(text)],
134                cache: false,
135            }],
136            tools: Vec::new(),
137            stop: Vec::new(),
138            temperature: None,
139        };
140
141        request
142    }
143}
144
145impl Render for MessageEditor {
146    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
147        let font_size = TextSize::Default.rems(cx);
148        let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
149        let focus_handle = self.editor.focus_handle(cx);
150
151        v_flex()
152            .key_context("MessageEditor")
153            .on_action(cx.listener(Self::chat))
154            .size_full()
155            .gap_2()
156            .p_2()
157            .bg(cx.theme().colors().editor_background)
158            .child({
159                let settings = ThemeSettings::get_global(cx);
160                let text_style = TextStyle {
161                    color: cx.theme().colors().editor_foreground,
162                    font_family: settings.ui_font.family.clone(),
163                    font_features: settings.ui_font.features.clone(),
164                    font_size: font_size.into(),
165                    font_weight: settings.ui_font.weight,
166                    line_height: line_height.into(),
167                    ..Default::default()
168                };
169
170                EditorElement::new(
171                    &self.editor,
172                    EditorStyle {
173                        background: cx.theme().colors().editor_background,
174                        local_player: cx.theme().players().local(),
175                        text: text_style,
176                        ..Default::default()
177                    },
178                )
179            })
180            .child(
181                h_flex()
182                    .justify_between()
183                    .child(
184                        h_flex().child(
185                            Button::new("add-context", "Add Context")
186                                .style(ButtonStyle::Filled)
187                                .icon(IconName::Plus)
188                                .icon_position(IconPosition::Start),
189                        ),
190                    )
191                    .child(
192                        h_flex()
193                            .gap_2()
194                            .child(Button::new("codebase", "Codebase").style(ButtonStyle::Filled))
195                            .child(Label::new("or"))
196                            .child(
197                                ButtonLike::new("chat")
198                                    .style(ButtonStyle::Filled)
199                                    .layer(ElevationIndex::ModalSurface)
200                                    .child(Label::new("Chat"))
201                                    .children(
202                                        KeyBinding::for_action_in(&Chat, &focus_handle, cx)
203                                            .map(|binding| binding.into_any_element()),
204                                    )
205                                    .on_click(move |_event, cx| {
206                                        focus_handle.dispatch_action(&Chat, cx);
207                                    }),
208                            ),
209                    ),
210            )
211    }
212}