active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AnyElement, Empty, ListAlignment, ListState, Model, StyleRefinement, Subscription,
  7    TextStyleRefinement, View, WeakView,
  8};
  9use language::LanguageRegistry;
 10use language_model::Role;
 11use markdown::{Markdown, MarkdownStyle};
 12use settings::Settings as _;
 13use theme::ThemeSettings;
 14use ui::prelude::*;
 15use workspace::Workspace;
 16
 17use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 18
 19pub struct ActiveThread {
 20    workspace: WeakView<Workspace>,
 21    language_registry: Arc<LanguageRegistry>,
 22    tools: Arc<ToolWorkingSet>,
 23    thread: Model<Thread>,
 24    messages: Vec<MessageId>,
 25    list_state: ListState,
 26    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 27    last_error: Option<ThreadError>,
 28    _subscriptions: Vec<Subscription>,
 29}
 30
 31impl ActiveThread {
 32    pub fn new(
 33        thread: Model<Thread>,
 34        workspace: WeakView<Workspace>,
 35        language_registry: Arc<LanguageRegistry>,
 36        tools: Arc<ToolWorkingSet>,
 37        cx: &mut ViewContext<Self>,
 38    ) -> Self {
 39        let subscriptions = vec![
 40            cx.observe(&thread, |_, _, cx| cx.notify()),
 41            cx.subscribe(&thread, Self::handle_thread_event),
 42        ];
 43
 44        let mut this = Self {
 45            workspace,
 46            language_registry,
 47            tools,
 48            thread: thread.clone(),
 49            messages: Vec::new(),
 50            rendered_messages_by_id: HashMap::default(),
 51            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 52                let this = cx.view().downgrade();
 53                move |ix, cx: &mut WindowContext| {
 54                    this.update(cx, |this, cx| this.render_message(ix, cx))
 55                        .unwrap()
 56                }
 57            }),
 58            last_error: None,
 59            _subscriptions: subscriptions,
 60        };
 61
 62        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 63            this.push_message(&message.id, message.text.clone(), cx);
 64        }
 65
 66        this
 67    }
 68
 69    pub fn is_empty(&self) -> bool {
 70        self.messages.is_empty()
 71    }
 72
 73    pub fn last_error(&self) -> Option<ThreadError> {
 74        self.last_error.clone()
 75    }
 76
 77    pub fn clear_last_error(&mut self) {
 78        self.last_error.take();
 79    }
 80
 81    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
 82        let old_len = self.messages.len();
 83        self.messages.push(*id);
 84        self.list_state.splice(old_len..old_len, 1);
 85
 86        let theme_settings = ThemeSettings::get_global(cx);
 87        let ui_font_size = TextSize::Default.rems(cx);
 88        let buffer_font_size = theme_settings.buffer_font_size;
 89
 90        let mut text_style = cx.text_style();
 91        text_style.refine(&TextStyleRefinement {
 92            font_family: Some(theme_settings.ui_font.family.clone()),
 93            font_size: Some(ui_font_size.into()),
 94            color: Some(cx.theme().colors().text),
 95            ..Default::default()
 96        });
 97
 98        let markdown_style = MarkdownStyle {
 99            base_text_style: text_style,
100            syntax: cx.theme().syntax().clone(),
101            selection_background_color: cx.theme().players().local().selection,
102            code_block: StyleRefinement {
103                text: Some(TextStyleRefinement {
104                    font_family: Some(theme_settings.buffer_font.family.clone()),
105                    font_size: Some(buffer_font_size.into()),
106                    ..Default::default()
107                }),
108                ..Default::default()
109            },
110            inline_code: TextStyleRefinement {
111                font_family: Some(theme_settings.buffer_font.family.clone()),
112                font_size: Some(ui_font_size.into()),
113                background_color: Some(cx.theme().colors().editor_background),
114                ..Default::default()
115            },
116            ..Default::default()
117        };
118
119        let markdown = cx.new_view(|cx| {
120            Markdown::new(
121                text,
122                markdown_style,
123                Some(self.language_registry.clone()),
124                None,
125                cx,
126            )
127        });
128        self.rendered_messages_by_id.insert(*id, markdown);
129    }
130
131    fn handle_thread_event(
132        &mut self,
133        _: Model<Thread>,
134        event: &ThreadEvent,
135        cx: &mut ViewContext<Self>,
136    ) {
137        match event {
138            ThreadEvent::ShowError(error) => {
139                self.last_error = Some(error.clone());
140            }
141            ThreadEvent::StreamedCompletion => {}
142            ThreadEvent::StreamedAssistantText(message_id, text) => {
143                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
144                    markdown.update(cx, |markdown, cx| {
145                        markdown.append(text, cx);
146                    });
147                }
148            }
149            ThreadEvent::MessageAdded(message_id) => {
150                if let Some(message_text) = self
151                    .thread
152                    .read(cx)
153                    .message(*message_id)
154                    .map(|message| message.text.clone())
155                {
156                    self.push_message(message_id, message_text, cx);
157                }
158
159                cx.notify();
160            }
161            ThreadEvent::UsePendingTools => {
162                let pending_tool_uses = self
163                    .thread
164                    .read(cx)
165                    .pending_tool_uses()
166                    .into_iter()
167                    .filter(|tool_use| tool_use.status.is_idle())
168                    .cloned()
169                    .collect::<Vec<_>>();
170
171                for tool_use in pending_tool_uses {
172                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
173                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
174
175                        self.thread.update(cx, |thread, cx| {
176                            thread.insert_tool_output(
177                                tool_use.assistant_message_id,
178                                tool_use.id.clone(),
179                                task,
180                                cx,
181                            );
182                        });
183                    }
184                }
185            }
186            ThreadEvent::ToolFinished { .. } => {}
187        }
188    }
189
190    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
191        let message_id = self.messages[ix];
192        let Some(message) = self.thread.read(cx).message(message_id) else {
193            return Empty.into_any();
194        };
195
196        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
197            return Empty.into_any();
198        };
199
200        let (role_icon, role_name) = match message.role {
201            Role::User => (IconName::Person, "You"),
202            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
203            Role::System => (IconName::Settings, "System"),
204        };
205
206        div()
207            .id(("message-container", ix))
208            .p_2()
209            .child(
210                v_flex()
211                    .border_1()
212                    .border_color(cx.theme().colors().border_variant)
213                    .rounded_md()
214                    .child(
215                        h_flex()
216                            .justify_between()
217                            .p_1p5()
218                            .border_b_1()
219                            .border_color(cx.theme().colors().border_variant)
220                            .child(
221                                h_flex()
222                                    .gap_2()
223                                    .child(Icon::new(role_icon).size(IconSize::Small))
224                                    .child(Label::new(role_name).size(LabelSize::Small)),
225                            ),
226                    )
227                    .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
228            )
229            .into_any()
230    }
231}
232
233impl Render for ActiveThread {
234    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
235        list(self.list_state.clone()).flex_1()
236    }
237}