1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use client::zed_urls;
6use collections::HashMap;
7use gpui::{
8 list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
9 FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels,
10 StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
11 WindowContext,
12};
13use language::LanguageRegistry;
14use language_model::{LanguageModelRegistry, Role};
15use language_model_selector::LanguageModelSelector;
16use markdown::{Markdown, MarkdownStyle};
17use settings::Settings;
18use theme::ThemeSettings;
19use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
20use workspace::dock::{DockPosition, Panel, PanelEvent};
21use workspace::Workspace;
22
23use crate::message_editor::MessageEditor;
24use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
25use crate::thread_store::ThreadStore;
26use crate::{NewThread, ToggleFocus, ToggleModelSelector};
27
28pub fn init(cx: &mut AppContext) {
29 cx.observe_new_views(
30 |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
31 workspace.register_action(|workspace, _: &ToggleFocus, cx| {
32 workspace.toggle_panel_focus::<AssistantPanel>(cx);
33 });
34 },
35 )
36 .detach();
37}
38
/// Dock panel hosting a single assistant [`Thread`]: a toolbar, the rendered message list,
/// a message editor, and an error overlay.
pub struct AssistantPanel {
    /// Weak handle to the owning workspace; handed to tools when running pending tool uses.
    workspace: WeakView<Workspace>,
    /// Passed to each message's [`Markdown`] view when it is created.
    language_registry: Arc<LanguageRegistry>,
    // Stored but not yet read anywhere in this file, hence the lint suppression.
    #[allow(unused)]
    thread_store: Model<ThreadStore>,
    /// The thread currently shown in the panel; replaced by `new_thread`.
    thread: Model<Thread>,
    /// Message ids in display order; index `ix` corresponds to row `ix` of the list.
    thread_messages: Vec<MessageId>,
    /// Markdown view for each message, appended to while assistant text streams in.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Backing state for the virtualized message list.
    thread_list_state: ListState,
    /// Editor used to compose messages for the current thread.
    message_editor: View<MessageEditor>,
    /// Tool set used to run the thread's pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// Most recent thread error, shown as an overlay until dismissed.
    last_error: Option<ThreadError>,
    /// Observation and event subscriptions on the current thread.
    _subscriptions: Vec<Subscription>,
}
53
54impl AssistantPanel {
55 pub fn load(
56 workspace: WeakView<Workspace>,
57 cx: AsyncWindowContext,
58 ) -> Task<Result<View<Self>>> {
59 cx.spawn(|mut cx| async move {
60 let tools = Arc::new(ToolWorkingSet::default());
61 let thread_store = workspace
62 .update(&mut cx, |workspace, cx| {
63 let project = workspace.project().clone();
64 ThreadStore::new(project, tools.clone(), cx)
65 })?
66 .await?;
67
68 workspace.update(&mut cx, |workspace, cx| {
69 cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
70 })
71 })
72 }
73
74 fn new(
75 workspace: &Workspace,
76 thread_store: Model<ThreadStore>,
77 tools: Arc<ToolWorkingSet>,
78 cx: &mut ViewContext<Self>,
79 ) -> Self {
80 let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
81 let subscriptions = vec![
82 cx.observe(&thread, |_, _, cx| cx.notify()),
83 cx.subscribe(&thread, Self::handle_thread_event),
84 ];
85
86 Self {
87 workspace: workspace.weak_handle(),
88 language_registry: workspace.project().read(cx).languages().clone(),
89 thread_store,
90 thread: thread.clone(),
91 thread_messages: Vec::new(),
92 rendered_messages_by_id: HashMap::default(),
93 thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
94 let this = cx.view().downgrade();
95 move |ix, cx: &mut WindowContext| {
96 this.update(cx, |this, cx| this.render_message(ix, cx))
97 .unwrap()
98 }
99 }),
100 message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
101 tools,
102 last_error: None,
103 _subscriptions: subscriptions,
104 }
105 }
106
107 fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
108 let tools = self.thread.read(cx).tools().clone();
109 let thread = cx.new_model(|cx| Thread::new(tools, cx));
110 let subscriptions = vec![
111 cx.observe(&thread, |_, _, cx| cx.notify()),
112 cx.subscribe(&thread, Self::handle_thread_event),
113 ];
114
115 self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
116 self.thread = thread;
117 self.thread_messages.clear();
118 self.thread_list_state.reset(0);
119 self.rendered_messages_by_id.clear();
120 self._subscriptions = subscriptions;
121
122 self.message_editor.focus_handle(cx).focus(cx);
123 }
124
125 fn handle_thread_event(
126 &mut self,
127 _: Model<Thread>,
128 event: &ThreadEvent,
129 cx: &mut ViewContext<Self>,
130 ) {
131 match event {
132 ThreadEvent::ShowError(error) => {
133 self.last_error = Some(error.clone());
134 }
135 ThreadEvent::StreamedCompletion => {}
136 ThreadEvent::StreamedAssistantText(message_id, text) => {
137 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
138 markdown.update(cx, |markdown, cx| {
139 markdown.append(text, cx);
140 });
141 }
142 }
143 ThreadEvent::MessageAdded(message_id) => {
144 let old_len = self.thread_messages.len();
145 self.thread_messages.push(*message_id);
146 self.thread_list_state.splice(old_len..old_len, 1);
147
148 if let Some(message_text) = self
149 .thread
150 .read(cx)
151 .message(*message_id)
152 .map(|message| message.text.clone())
153 {
154 let theme_settings = ThemeSettings::get_global(cx);
155 let ui_font_size = TextSize::Default.rems(cx);
156 let buffer_font_size = theme_settings.buffer_font_size;
157
158 let mut text_style = cx.text_style();
159 text_style.refine(&TextStyleRefinement {
160 font_family: Some(theme_settings.ui_font.family.clone()),
161 font_size: Some(ui_font_size.into()),
162 color: Some(cx.theme().colors().text),
163 ..Default::default()
164 });
165
166 let markdown_style = MarkdownStyle {
167 base_text_style: text_style,
168 syntax: cx.theme().syntax().clone(),
169 selection_background_color: cx.theme().players().local().selection,
170 code_block: StyleRefinement {
171 text: Some(TextStyleRefinement {
172 font_family: Some(theme_settings.buffer_font.family.clone()),
173 font_size: Some(buffer_font_size.into()),
174 ..Default::default()
175 }),
176 ..Default::default()
177 },
178 inline_code: TextStyleRefinement {
179 font_family: Some(theme_settings.buffer_font.family.clone()),
180 font_size: Some(ui_font_size.into()),
181 background_color: Some(cx.theme().colors().editor_background),
182 ..Default::default()
183 },
184 ..Default::default()
185 };
186
187 let markdown = cx.new_view(|cx| {
188 Markdown::new(
189 message_text,
190 markdown_style,
191 Some(self.language_registry.clone()),
192 None,
193 cx,
194 )
195 });
196 self.rendered_messages_by_id.insert(*message_id, markdown);
197 }
198
199 cx.notify();
200 }
201 ThreadEvent::UsePendingTools => {
202 let pending_tool_uses = self
203 .thread
204 .read(cx)
205 .pending_tool_uses()
206 .into_iter()
207 .filter(|tool_use| tool_use.status.is_idle())
208 .cloned()
209 .collect::<Vec<_>>();
210
211 for tool_use in pending_tool_uses {
212 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
213 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
214
215 self.thread.update(cx, |thread, cx| {
216 thread.insert_tool_output(
217 tool_use.assistant_message_id,
218 tool_use.id.clone(),
219 task,
220 cx,
221 );
222 });
223 }
224 }
225 }
226 ThreadEvent::ToolFinished { .. } => {}
227 }
228 }
229}
230
231impl FocusableView for AssistantPanel {
232 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
233 self.message_editor.focus_handle(cx)
234 }
235}
236
// Declares that the panel can emit `PanelEvent`s; no events are emitted from the code shown here.
impl EventEmitter<PanelEvent> for AssistantPanel {}
238
239impl Panel for AssistantPanel {
240 fn persistent_name() -> &'static str {
241 "AssistantPanel2"
242 }
243
244 fn position(&self, _cx: &WindowContext) -> DockPosition {
245 DockPosition::Right
246 }
247
248 fn position_is_valid(&self, _: DockPosition) -> bool {
249 true
250 }
251
252 fn set_position(&mut self, _position: DockPosition, _cx: &mut ViewContext<Self>) {}
253
254 fn size(&self, _cx: &WindowContext) -> Pixels {
255 px(640.)
256 }
257
258 fn set_size(&mut self, _size: Option<Pixels>, _cx: &mut ViewContext<Self>) {}
259
260 fn set_active(&mut self, _active: bool, _cx: &mut ViewContext<Self>) {}
261
262 fn remote_id() -> Option<proto::PanelId> {
263 Some(proto::PanelId::AssistantPanel)
264 }
265
266 fn icon(&self, _cx: &WindowContext) -> Option<IconName> {
267 Some(IconName::ZedAssistant)
268 }
269
270 fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
271 Some("Assistant Panel")
272 }
273
274 fn toggle_action(&self) -> Box<dyn Action> {
275 Box::new(ToggleFocus)
276 }
277}
278
279impl AssistantPanel {
280 fn render_toolbar(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
281 let focus_handle = self.focus_handle(cx);
282
283 h_flex()
284 .id("assistant-toolbar")
285 .justify_between()
286 .gap(DynamicSpacing::Base08.rems(cx))
287 .h(Tab::container_height(cx))
288 .px(DynamicSpacing::Base08.rems(cx))
289 .bg(cx.theme().colors().tab_bar_background)
290 .border_b_1()
291 .border_color(cx.theme().colors().border_variant)
292 .child(h_flex().child(Label::new("Thread Title Goes Here")))
293 .child(
294 h_flex()
295 .gap(DynamicSpacing::Base08.rems(cx))
296 .child(self.render_language_model_selector(cx))
297 .child(Divider::vertical())
298 .child(
299 IconButton::new("new-thread", IconName::Plus)
300 .shape(IconButtonShape::Square)
301 .icon_size(IconSize::Small)
302 .style(ButtonStyle::Subtle)
303 .tooltip({
304 let focus_handle = focus_handle.clone();
305 move |cx| {
306 Tooltip::for_action_in(
307 "New Thread",
308 &NewThread,
309 &focus_handle,
310 cx,
311 )
312 }
313 })
314 .on_click(move |_event, _cx| {
315 println!("New Thread");
316 }),
317 )
318 .child(
319 IconButton::new("open-history", IconName::HistoryRerun)
320 .shape(IconButtonShape::Square)
321 .icon_size(IconSize::Small)
322 .style(ButtonStyle::Subtle)
323 .tooltip(move |cx| Tooltip::text("Open History", cx))
324 .on_click(move |_event, _cx| {
325 println!("Open History");
326 }),
327 )
328 .child(
329 IconButton::new("configure-assistant", IconName::Settings)
330 .shape(IconButtonShape::Square)
331 .icon_size(IconSize::Small)
332 .style(ButtonStyle::Subtle)
333 .tooltip(move |cx| Tooltip::text("Configure Assistant", cx))
334 .on_click(move |_event, _cx| {
335 println!("Configure Assistant");
336 }),
337 ),
338 )
339 }
340
341 fn render_language_model_selector(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
342 let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
343 let active_model = LanguageModelRegistry::read_global(cx).active_model();
344
345 LanguageModelSelector::new(
346 |model, _cx| {
347 println!("Selected {:?}", model.name());
348 },
349 ButtonLike::new("active-model")
350 .style(ButtonStyle::Subtle)
351 .child(
352 h_flex()
353 .w_full()
354 .gap_0p5()
355 .child(
356 div()
357 .overflow_x_hidden()
358 .flex_grow()
359 .whitespace_nowrap()
360 .child(match (active_provider, active_model) {
361 (Some(provider), Some(model)) => h_flex()
362 .gap_1()
363 .child(
364 Icon::new(
365 model.icon().unwrap_or_else(|| provider.icon()),
366 )
367 .color(Color::Muted)
368 .size(IconSize::XSmall),
369 )
370 .child(
371 Label::new(model.name().0)
372 .size(LabelSize::Small)
373 .color(Color::Muted),
374 )
375 .into_any_element(),
376 _ => Label::new("No model selected")
377 .size(LabelSize::Small)
378 .color(Color::Muted)
379 .into_any_element(),
380 }),
381 )
382 .child(
383 Icon::new(IconName::ChevronDown)
384 .color(Color::Muted)
385 .size(IconSize::XSmall),
386 ),
387 )
388 .tooltip(move |cx| Tooltip::for_action("Change Model", &ToggleModelSelector, cx)),
389 )
390 }
391
392 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
393 let message_id = self.thread_messages[ix];
394 let Some(message) = self.thread.read(cx).message(message_id) else {
395 return Empty.into_any();
396 };
397
398 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
399 return Empty.into_any();
400 };
401
402 let (role_icon, role_name) = match message.role {
403 Role::User => (IconName::Person, "You"),
404 Role::Assistant => (IconName::ZedAssistant, "Assistant"),
405 Role::System => (IconName::Settings, "System"),
406 };
407
408 div()
409 .id(("message-container", ix))
410 .p_2()
411 .child(
412 v_flex()
413 .border_1()
414 .border_color(cx.theme().colors().border_variant)
415 .rounded_md()
416 .child(
417 h_flex()
418 .justify_between()
419 .p_1p5()
420 .border_b_1()
421 .border_color(cx.theme().colors().border_variant)
422 .child(
423 h_flex()
424 .gap_2()
425 .child(Icon::new(role_icon).size(IconSize::Small))
426 .child(Label::new(role_name).size(LabelSize::Small)),
427 ),
428 )
429 .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
430 )
431 .into_any()
432 }
433
434 fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
435 let last_error = self.last_error.as_ref()?;
436
437 Some(
438 div()
439 .absolute()
440 .right_3()
441 .bottom_12()
442 .max_w_96()
443 .py_2()
444 .px_3()
445 .elevation_2(cx)
446 .occlude()
447 .child(match last_error {
448 ThreadError::PaymentRequired => self.render_payment_required_error(cx),
449 ThreadError::MaxMonthlySpendReached => {
450 self.render_max_monthly_spend_reached_error(cx)
451 }
452 ThreadError::Message(error_message) => {
453 self.render_error_message(error_message, cx)
454 }
455 })
456 .into_any(),
457 )
458 }
459
460 fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
461 const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
462
463 v_flex()
464 .gap_0p5()
465 .child(
466 h_flex()
467 .gap_1p5()
468 .items_center()
469 .child(Icon::new(IconName::XCircle).color(Color::Error))
470 .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
471 )
472 .child(
473 div()
474 .id("error-message")
475 .max_h_24()
476 .overflow_y_scroll()
477 .child(Label::new(ERROR_MESSAGE)),
478 )
479 .child(
480 h_flex()
481 .justify_end()
482 .mt_1()
483 .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
484 |this, _, cx| {
485 this.last_error = None;
486 cx.open_url(&zed_urls::account_url(cx));
487 cx.notify();
488 },
489 )))
490 .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
491 |this, _, cx| {
492 this.last_error = None;
493 cx.notify();
494 },
495 ))),
496 )
497 .into_any()
498 }
499
500 fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
501 const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
502
503 v_flex()
504 .gap_0p5()
505 .child(
506 h_flex()
507 .gap_1p5()
508 .items_center()
509 .child(Icon::new(IconName::XCircle).color(Color::Error))
510 .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
511 )
512 .child(
513 div()
514 .id("error-message")
515 .max_h_24()
516 .overflow_y_scroll()
517 .child(Label::new(ERROR_MESSAGE)),
518 )
519 .child(
520 h_flex()
521 .justify_end()
522 .mt_1()
523 .child(
524 Button::new("subscribe", "Update Monthly Spend Limit").on_click(
525 cx.listener(|this, _, cx| {
526 this.last_error = None;
527 cx.open_url(&zed_urls::account_url(cx));
528 cx.notify();
529 }),
530 ),
531 )
532 .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
533 |this, _, cx| {
534 this.last_error = None;
535 cx.notify();
536 },
537 ))),
538 )
539 .into_any()
540 }
541
542 fn render_error_message(
543 &self,
544 error_message: &SharedString,
545 cx: &mut ViewContext<Self>,
546 ) -> AnyElement {
547 v_flex()
548 .gap_0p5()
549 .child(
550 h_flex()
551 .gap_1p5()
552 .items_center()
553 .child(Icon::new(IconName::XCircle).color(Color::Error))
554 .child(
555 Label::new("Error interacting with language model")
556 .weight(FontWeight::MEDIUM),
557 ),
558 )
559 .child(
560 div()
561 .id("error-message")
562 .max_h_32()
563 .overflow_y_scroll()
564 .child(Label::new(error_message.clone())),
565 )
566 .child(
567 h_flex()
568 .justify_end()
569 .mt_1()
570 .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
571 |this, _, cx| {
572 this.last_error = None;
573 cx.notify();
574 },
575 ))),
576 )
577 .into_any()
578 }
579}
580
581impl Render for AssistantPanel {
582 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
583 v_flex()
584 .key_context("AssistantPanel2")
585 .justify_between()
586 .size_full()
587 .on_action(cx.listener(|this, _: &NewThread, cx| {
588 this.new_thread(cx);
589 }))
590 .child(self.render_toolbar(cx))
591 .child(list(self.thread_list_state.clone()).flex_1())
592 .child(
593 h_flex()
594 .border_t_1()
595 .border_color(cx.theme().colors().border_variant)
596 .child(self.message_editor.clone()),
597 )
598 .children(self.render_last_error(cx))
599 }
600}