// active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AnyElement, AppContext, Empty, ListAlignment, ListState, Model, StyleRefinement,
  7    Subscription, TextStyleRefinement, View, WeakView,
  8};
  9use language::LanguageRegistry;
 10use language_model::Role;
 11use markdown::{Markdown, MarkdownStyle};
 12use settings::Settings as _;
 13use theme::ThemeSettings;
 14use ui::prelude::*;
 15use workspace::Workspace;
 16
 17use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 18
 19pub struct ActiveThread {
 20    workspace: WeakView<Workspace>,
 21    language_registry: Arc<LanguageRegistry>,
 22    tools: Arc<ToolWorkingSet>,
 23    thread: Model<Thread>,
 24    messages: Vec<MessageId>,
 25    list_state: ListState,
 26    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 27    last_error: Option<ThreadError>,
 28    _subscriptions: Vec<Subscription>,
 29}
 30
 31impl ActiveThread {
 32    pub fn new(
 33        thread: Model<Thread>,
 34        workspace: WeakView<Workspace>,
 35        language_registry: Arc<LanguageRegistry>,
 36        tools: Arc<ToolWorkingSet>,
 37        cx: &mut ViewContext<Self>,
 38    ) -> Self {
 39        let subscriptions = vec![
 40            cx.observe(&thread, |_, _, cx| cx.notify()),
 41            cx.subscribe(&thread, Self::handle_thread_event),
 42        ];
 43
 44        let mut this = Self {
 45            workspace,
 46            language_registry,
 47            tools,
 48            thread: thread.clone(),
 49            messages: Vec::new(),
 50            rendered_messages_by_id: HashMap::default(),
 51            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 52                let this = cx.view().downgrade();
 53                move |ix, cx: &mut WindowContext| {
 54                    this.update(cx, |this, cx| this.render_message(ix, cx))
 55                        .unwrap()
 56                }
 57            }),
 58            last_error: None,
 59            _subscriptions: subscriptions,
 60        };
 61
 62        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 63            this.push_message(&message.id, message.text.clone(), cx);
 64        }
 65
 66        this
 67    }
 68
 69    pub fn is_empty(&self) -> bool {
 70        self.messages.is_empty()
 71    }
 72
 73    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 74        self.thread.read(cx).summary()
 75    }
 76
 77    pub fn last_error(&self) -> Option<ThreadError> {
 78        self.last_error.clone()
 79    }
 80
 81    pub fn clear_last_error(&mut self) {
 82        self.last_error.take();
 83    }
 84
 85    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
 86        let old_len = self.messages.len();
 87        self.messages.push(*id);
 88        self.list_state.splice(old_len..old_len, 1);
 89
 90        let theme_settings = ThemeSettings::get_global(cx);
 91        let ui_font_size = TextSize::Default.rems(cx);
 92        let buffer_font_size = theme_settings.buffer_font_size;
 93
 94        let mut text_style = cx.text_style();
 95        text_style.refine(&TextStyleRefinement {
 96            font_family: Some(theme_settings.ui_font.family.clone()),
 97            font_size: Some(ui_font_size.into()),
 98            color: Some(cx.theme().colors().text),
 99            ..Default::default()
100        });
101
102        let markdown_style = MarkdownStyle {
103            base_text_style: text_style,
104            syntax: cx.theme().syntax().clone(),
105            selection_background_color: cx.theme().players().local().selection,
106            code_block: StyleRefinement {
107                text: Some(TextStyleRefinement {
108                    font_family: Some(theme_settings.buffer_font.family.clone()),
109                    font_size: Some(buffer_font_size.into()),
110                    ..Default::default()
111                }),
112                ..Default::default()
113            },
114            inline_code: TextStyleRefinement {
115                font_family: Some(theme_settings.buffer_font.family.clone()),
116                font_size: Some(ui_font_size.into()),
117                background_color: Some(cx.theme().colors().editor_background),
118                ..Default::default()
119            },
120            ..Default::default()
121        };
122
123        let markdown = cx.new_view(|cx| {
124            Markdown::new(
125                text,
126                markdown_style,
127                Some(self.language_registry.clone()),
128                None,
129                cx,
130            )
131        });
132        self.rendered_messages_by_id.insert(*id, markdown);
133    }
134
135    fn handle_thread_event(
136        &mut self,
137        _: Model<Thread>,
138        event: &ThreadEvent,
139        cx: &mut ViewContext<Self>,
140    ) {
141        match event {
142            ThreadEvent::ShowError(error) => {
143                self.last_error = Some(error.clone());
144            }
145            ThreadEvent::StreamedCompletion => {}
146            ThreadEvent::SummaryChanged => {}
147            ThreadEvent::StreamedAssistantText(message_id, text) => {
148                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
149                    markdown.update(cx, |markdown, cx| {
150                        markdown.append(text, cx);
151                    });
152                }
153            }
154            ThreadEvent::MessageAdded(message_id) => {
155                if let Some(message_text) = self
156                    .thread
157                    .read(cx)
158                    .message(*message_id)
159                    .map(|message| message.text.clone())
160                {
161                    self.push_message(message_id, message_text, cx);
162                }
163
164                cx.notify();
165            }
166            ThreadEvent::UsePendingTools => {
167                let pending_tool_uses = self
168                    .thread
169                    .read(cx)
170                    .pending_tool_uses()
171                    .into_iter()
172                    .filter(|tool_use| tool_use.status.is_idle())
173                    .cloned()
174                    .collect::<Vec<_>>();
175
176                for tool_use in pending_tool_uses {
177                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
178                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
179
180                        self.thread.update(cx, |thread, cx| {
181                            thread.insert_tool_output(
182                                tool_use.assistant_message_id,
183                                tool_use.id.clone(),
184                                task,
185                                cx,
186                            );
187                        });
188                    }
189                }
190            }
191            ThreadEvent::ToolFinished { .. } => {}
192        }
193    }
194
195    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
196        let message_id = self.messages[ix];
197        let Some(message) = self.thread.read(cx).message(message_id) else {
198            return Empty.into_any();
199        };
200
201        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
202            return Empty.into_any();
203        };
204
205        let (role_icon, role_name) = match message.role {
206            Role::User => (IconName::Person, "You"),
207            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
208            Role::System => (IconName::Settings, "System"),
209        };
210
211        div()
212            .id(("message-container", ix))
213            .p_2()
214            .child(
215                v_flex()
216                    .border_1()
217                    .border_color(cx.theme().colors().border_variant)
218                    .rounded_md()
219                    .child(
220                        h_flex()
221                            .justify_between()
222                            .p_1p5()
223                            .border_b_1()
224                            .border_color(cx.theme().colors().border_variant)
225                            .child(
226                                h_flex()
227                                    .gap_2()
228                                    .child(Icon::new(role_icon).size(IconSize::Small))
229                                    .child(Label::new(role_name).size(LabelSize::Small)),
230                            ),
231                    )
232                    .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
233            )
234            .into_any()
235    }
236}
237
238impl Render for ActiveThread {
239    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
240        list(self.list_state.clone()).flex_1()
241    }
242}