1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AnyElement, AppContext, Empty, ListAlignment, ListState, Model, StyleRefinement,
7 Subscription, TextStyleRefinement, View, WeakView,
8};
9use language::LanguageRegistry;
10use language_model::Role;
11use markdown::{Markdown, MarkdownStyle};
12use settings::Settings as _;
13use theme::ThemeSettings;
14use ui::prelude::*;
15use workspace::Workspace;
16
17use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
18use crate::ui::ContextPill;
19
20pub struct ActiveThread {
21 workspace: WeakView<Workspace>,
22 language_registry: Arc<LanguageRegistry>,
23 tools: Arc<ToolWorkingSet>,
24 thread: Model<Thread>,
25 messages: Vec<MessageId>,
26 list_state: ListState,
27 rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
28 last_error: Option<ThreadError>,
29 _subscriptions: Vec<Subscription>,
30}
31
32impl ActiveThread {
33 pub fn new(
34 thread: Model<Thread>,
35 workspace: WeakView<Workspace>,
36 language_registry: Arc<LanguageRegistry>,
37 tools: Arc<ToolWorkingSet>,
38 cx: &mut ViewContext<Self>,
39 ) -> Self {
40 let subscriptions = vec![
41 cx.observe(&thread, |_, _, cx| cx.notify()),
42 cx.subscribe(&thread, Self::handle_thread_event),
43 ];
44
45 let mut this = Self {
46 workspace,
47 language_registry,
48 tools,
49 thread: thread.clone(),
50 messages: Vec::new(),
51 rendered_messages_by_id: HashMap::default(),
52 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
53 let this = cx.view().downgrade();
54 move |ix, cx: &mut WindowContext| {
55 this.update(cx, |this, cx| this.render_message(ix, cx))
56 .unwrap()
57 }
58 }),
59 last_error: None,
60 _subscriptions: subscriptions,
61 };
62
63 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
64 this.push_message(&message.id, message.text.clone(), cx);
65 }
66
67 this
68 }
69
70 pub fn is_empty(&self) -> bool {
71 self.messages.is_empty()
72 }
73
74 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
75 self.thread.read(cx).summary()
76 }
77
78 pub fn last_error(&self) -> Option<ThreadError> {
79 self.last_error.clone()
80 }
81
82 pub fn clear_last_error(&mut self) {
83 self.last_error.take();
84 }
85
86 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
87 let old_len = self.messages.len();
88 self.messages.push(*id);
89 self.list_state.splice(old_len..old_len, 1);
90
91 let theme_settings = ThemeSettings::get_global(cx);
92 let ui_font_size = TextSize::Default.rems(cx);
93 let buffer_font_size = theme_settings.buffer_font_size;
94
95 let mut text_style = cx.text_style();
96 text_style.refine(&TextStyleRefinement {
97 font_family: Some(theme_settings.ui_font.family.clone()),
98 font_size: Some(ui_font_size.into()),
99 color: Some(cx.theme().colors().text),
100 ..Default::default()
101 });
102
103 let markdown_style = MarkdownStyle {
104 base_text_style: text_style,
105 syntax: cx.theme().syntax().clone(),
106 selection_background_color: cx.theme().players().local().selection,
107 code_block: StyleRefinement {
108 text: Some(TextStyleRefinement {
109 font_family: Some(theme_settings.buffer_font.family.clone()),
110 font_size: Some(buffer_font_size.into()),
111 ..Default::default()
112 }),
113 ..Default::default()
114 },
115 inline_code: TextStyleRefinement {
116 font_family: Some(theme_settings.buffer_font.family.clone()),
117 font_size: Some(ui_font_size.into()),
118 background_color: Some(cx.theme().colors().editor_background),
119 ..Default::default()
120 },
121 ..Default::default()
122 };
123
124 let markdown = cx.new_view(|cx| {
125 Markdown::new(
126 text,
127 markdown_style,
128 Some(self.language_registry.clone()),
129 None,
130 cx,
131 )
132 });
133 self.rendered_messages_by_id.insert(*id, markdown);
134 }
135
136 fn handle_thread_event(
137 &mut self,
138 _: Model<Thread>,
139 event: &ThreadEvent,
140 cx: &mut ViewContext<Self>,
141 ) {
142 match event {
143 ThreadEvent::ShowError(error) => {
144 self.last_error = Some(error.clone());
145 }
146 ThreadEvent::StreamedCompletion => {}
147 ThreadEvent::SummaryChanged => {}
148 ThreadEvent::StreamedAssistantText(message_id, text) => {
149 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
150 markdown.update(cx, |markdown, cx| {
151 markdown.append(text, cx);
152 });
153 }
154 }
155 ThreadEvent::MessageAdded(message_id) => {
156 if let Some(message_text) = self
157 .thread
158 .read(cx)
159 .message(*message_id)
160 .map(|message| message.text.clone())
161 {
162 self.push_message(message_id, message_text, cx);
163 }
164
165 cx.notify();
166 }
167 ThreadEvent::UsePendingTools => {
168 let pending_tool_uses = self
169 .thread
170 .read(cx)
171 .pending_tool_uses()
172 .into_iter()
173 .filter(|tool_use| tool_use.status.is_idle())
174 .cloned()
175 .collect::<Vec<_>>();
176
177 for tool_use in pending_tool_uses {
178 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
179 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
180
181 self.thread.update(cx, |thread, cx| {
182 thread.insert_tool_output(
183 tool_use.assistant_message_id,
184 tool_use.id.clone(),
185 task,
186 cx,
187 );
188 });
189 }
190 }
191 }
192 ThreadEvent::ToolFinished { .. } => {}
193 }
194 }
195
196 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
197 let message_id = self.messages[ix];
198 let Some(message) = self.thread.read(cx).message(message_id) else {
199 return Empty.into_any();
200 };
201
202 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
203 return Empty.into_any();
204 };
205
206 let context = self.thread.read(cx).context_for_message(message_id);
207
208 let (role_icon, role_name) = match message.role {
209 Role::User => (IconName::Person, "You"),
210 Role::Assistant => (IconName::ZedAssistant, "Assistant"),
211 Role::System => (IconName::Settings, "System"),
212 };
213
214 div()
215 .id(("message-container", ix))
216 .p_2()
217 .child(
218 v_flex()
219 .border_1()
220 .border_color(cx.theme().colors().border_variant)
221 .rounded_md()
222 .child(
223 h_flex()
224 .justify_between()
225 .p_1p5()
226 .border_b_1()
227 .border_color(cx.theme().colors().border_variant)
228 .child(
229 h_flex()
230 .gap_2()
231 .child(Icon::new(role_icon).size(IconSize::Small))
232 .child(Label::new(role_name).size(LabelSize::Small)),
233 ),
234 )
235 .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone()))
236 .when_some(context, |parent, context| {
237 parent.child(
238 h_flex().flex_wrap().gap_2().p_1p5().children(
239 context
240 .iter()
241 .map(|context| ContextPill::new(context.clone())),
242 ),
243 )
244 }),
245 )
246 .into_any()
247 }
248}
249
250impl Render for ActiveThread {
251 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
252 list(self.list_state.clone()).flex_1()
253 }
254}