1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AnyElement, AppContext, Empty, ListAlignment, ListState, Model, StyleRefinement,
7 Subscription, TextStyleRefinement, View, WeakView,
8};
9use language::LanguageRegistry;
10use language_model::Role;
11use markdown::{Markdown, MarkdownStyle};
12use settings::Settings as _;
13use theme::ThemeSettings;
14use ui::prelude::*;
15use workspace::Workspace;
16
17use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
18
/// View state for a single assistant thread: keeps the thread's messages in
/// display order, their rendered markdown, and the most recent error.
pub struct ActiveThread {
    /// Handed to tools when pending tool uses are run.
    workspace: WeakView<Workspace>,
    /// Passed to each message's markdown view.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks for pending tool uses to run.
    tools: Arc<ToolWorkingSet>,
    /// The thread model this view observes and renders.
    thread: Model<Thread>,
    /// Message ids in display order; index `ix` corresponds to list item `ix`.
    messages: Vec<MessageId>,
    /// Backing state for the message list; its item count is kept in step with `messages`.
    list_state: ListState,
    /// Markdown views per message, appended to as assistant text streams in.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported by the thread, until cleared.
    last_error: Option<ThreadError>,
    // Keeps the thread observation/subscription alive for the view's lifetime.
    // Declared last so it is dropped last.
    _subscriptions: Vec<Subscription>,
}
30
31impl ActiveThread {
32 pub fn new(
33 thread: Model<Thread>,
34 workspace: WeakView<Workspace>,
35 language_registry: Arc<LanguageRegistry>,
36 tools: Arc<ToolWorkingSet>,
37 cx: &mut ViewContext<Self>,
38 ) -> Self {
39 let subscriptions = vec![
40 cx.observe(&thread, |_, _, cx| cx.notify()),
41 cx.subscribe(&thread, Self::handle_thread_event),
42 ];
43
44 let mut this = Self {
45 workspace,
46 language_registry,
47 tools,
48 thread: thread.clone(),
49 messages: Vec::new(),
50 rendered_messages_by_id: HashMap::default(),
51 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
52 let this = cx.view().downgrade();
53 move |ix, cx: &mut WindowContext| {
54 this.update(cx, |this, cx| this.render_message(ix, cx))
55 .unwrap()
56 }
57 }),
58 last_error: None,
59 _subscriptions: subscriptions,
60 };
61
62 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
63 this.push_message(&message.id, message.text.clone(), cx);
64 }
65
66 this
67 }
68
69 pub fn is_empty(&self) -> bool {
70 self.messages.is_empty()
71 }
72
73 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
74 self.thread.read(cx).summary()
75 }
76
77 pub fn last_error(&self) -> Option<ThreadError> {
78 self.last_error.clone()
79 }
80
81 pub fn clear_last_error(&mut self) {
82 self.last_error.take();
83 }
84
85 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
86 let old_len = self.messages.len();
87 self.messages.push(*id);
88 self.list_state.splice(old_len..old_len, 1);
89
90 let theme_settings = ThemeSettings::get_global(cx);
91 let ui_font_size = TextSize::Default.rems(cx);
92 let buffer_font_size = theme_settings.buffer_font_size;
93
94 let mut text_style = cx.text_style();
95 text_style.refine(&TextStyleRefinement {
96 font_family: Some(theme_settings.ui_font.family.clone()),
97 font_size: Some(ui_font_size.into()),
98 color: Some(cx.theme().colors().text),
99 ..Default::default()
100 });
101
102 let markdown_style = MarkdownStyle {
103 base_text_style: text_style,
104 syntax: cx.theme().syntax().clone(),
105 selection_background_color: cx.theme().players().local().selection,
106 code_block: StyleRefinement {
107 text: Some(TextStyleRefinement {
108 font_family: Some(theme_settings.buffer_font.family.clone()),
109 font_size: Some(buffer_font_size.into()),
110 ..Default::default()
111 }),
112 ..Default::default()
113 },
114 inline_code: TextStyleRefinement {
115 font_family: Some(theme_settings.buffer_font.family.clone()),
116 font_size: Some(ui_font_size.into()),
117 background_color: Some(cx.theme().colors().editor_background),
118 ..Default::default()
119 },
120 ..Default::default()
121 };
122
123 let markdown = cx.new_view(|cx| {
124 Markdown::new(
125 text,
126 markdown_style,
127 Some(self.language_registry.clone()),
128 None,
129 cx,
130 )
131 });
132 self.rendered_messages_by_id.insert(*id, markdown);
133 }
134
135 fn handle_thread_event(
136 &mut self,
137 _: Model<Thread>,
138 event: &ThreadEvent,
139 cx: &mut ViewContext<Self>,
140 ) {
141 match event {
142 ThreadEvent::ShowError(error) => {
143 self.last_error = Some(error.clone());
144 }
145 ThreadEvent::StreamedCompletion => {}
146 ThreadEvent::SummaryChanged => {}
147 ThreadEvent::StreamedAssistantText(message_id, text) => {
148 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
149 markdown.update(cx, |markdown, cx| {
150 markdown.append(text, cx);
151 });
152 }
153 }
154 ThreadEvent::MessageAdded(message_id) => {
155 if let Some(message_text) = self
156 .thread
157 .read(cx)
158 .message(*message_id)
159 .map(|message| message.text.clone())
160 {
161 self.push_message(message_id, message_text, cx);
162 }
163
164 cx.notify();
165 }
166 ThreadEvent::UsePendingTools => {
167 let pending_tool_uses = self
168 .thread
169 .read(cx)
170 .pending_tool_uses()
171 .into_iter()
172 .filter(|tool_use| tool_use.status.is_idle())
173 .cloned()
174 .collect::<Vec<_>>();
175
176 for tool_use in pending_tool_uses {
177 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
178 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
179
180 self.thread.update(cx, |thread, cx| {
181 thread.insert_tool_output(
182 tool_use.assistant_message_id,
183 tool_use.id.clone(),
184 task,
185 cx,
186 );
187 });
188 }
189 }
190 }
191 ThreadEvent::ToolFinished { .. } => {}
192 }
193 }
194
195 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
196 let message_id = self.messages[ix];
197 let Some(message) = self.thread.read(cx).message(message_id) else {
198 return Empty.into_any();
199 };
200
201 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
202 return Empty.into_any();
203 };
204
205 let (role_icon, role_name) = match message.role {
206 Role::User => (IconName::Person, "You"),
207 Role::Assistant => (IconName::ZedAssistant, "Assistant"),
208 Role::System => (IconName::Settings, "System"),
209 };
210
211 div()
212 .id(("message-container", ix))
213 .p_2()
214 .child(
215 v_flex()
216 .border_1()
217 .border_color(cx.theme().colors().border_variant)
218 .rounded_md()
219 .child(
220 h_flex()
221 .justify_between()
222 .p_1p5()
223 .border_b_1()
224 .border_color(cx.theme().colors().border_variant)
225 .child(
226 h_flex()
227 .gap_2()
228 .child(Icon::new(role_icon).size(IconSize::Small))
229 .child(Label::new(role_name).size(LabelSize::Small)),
230 ),
231 )
232 .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
233 )
234 .into_any()
235 }
236}
237
238impl Render for ActiveThread {
239 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
240 list(self.list_state.clone()).flex_1()
241 }
242}