assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use gpui::{Action, AnyView, App, EventEmitter, FocusHandle, Focusable, Subscription};
  5use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  6use ui::{prelude::*, Divider, DividerColor, ElevationIndex};
  7use zed_actions::assistant::DeployPromptLibrary;
  8
  9pub struct AssistantConfiguration {
 10    focus_handle: FocusHandle,
 11    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 12    _registry_subscription: Subscription,
 13}
 14
 15impl AssistantConfiguration {
 16    pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
 17        let focus_handle = cx.focus_handle();
 18
 19        let registry_subscription = cx.subscribe_in(
 20            &LanguageModelRegistry::global(cx),
 21            window,
 22            |this, _, event: &language_model::Event, window, cx| match event {
 23                language_model::Event::AddedProvider(provider_id) => {
 24                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 25                    if let Some(provider) = provider {
 26                        this.add_provider_configuration_view(&provider, window, cx);
 27                    }
 28                }
 29                language_model::Event::RemovedProvider(provider_id) => {
 30                    this.remove_provider_configuration_view(provider_id);
 31                }
 32                _ => {}
 33            },
 34        );
 35
 36        let mut this = Self {
 37            focus_handle,
 38            configuration_views_by_provider: HashMap::default(),
 39            _registry_subscription: registry_subscription,
 40        };
 41        this.build_provider_configuration_views(window, cx);
 42        this
 43    }
 44
 45    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 46        let providers = LanguageModelRegistry::read_global(cx).providers();
 47        for provider in providers {
 48            self.add_provider_configuration_view(&provider, window, cx);
 49        }
 50    }
 51
 52    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 53        self.configuration_views_by_provider.remove(provider_id);
 54    }
 55
 56    fn add_provider_configuration_view(
 57        &mut self,
 58        provider: &Arc<dyn LanguageModelProvider>,
 59        window: &mut Window,
 60        cx: &mut Context<Self>,
 61    ) {
 62        let configuration_view = provider.configuration_view(window, cx);
 63        self.configuration_views_by_provider
 64            .insert(provider.id(), configuration_view);
 65    }
 66}
 67
 68impl Focusable for AssistantConfiguration {
 69    fn focus_handle(&self, _: &App) -> FocusHandle {
 70        self.focus_handle.clone()
 71    }
 72}
 73
 74pub enum AssistantConfigurationEvent {
 75    NewThread(Arc<dyn LanguageModelProvider>),
 76}
 77
 78impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 79
 80impl AssistantConfiguration {
 81    fn render_provider_configuration(
 82        &mut self,
 83        provider: &Arc<dyn LanguageModelProvider>,
 84        cx: &mut Context<Self>,
 85    ) -> impl IntoElement {
 86        let provider_id = provider.id().0.clone();
 87        let provider_name = provider.name().0.clone();
 88        let configuration_view = self
 89            .configuration_views_by_provider
 90            .get(&provider.id())
 91            .cloned();
 92
 93        v_flex()
 94            .gap_1p5()
 95            .child(
 96                h_flex()
 97                    .justify_between()
 98                    .child(
 99                        h_flex()
100                            .gap_2()
101                            .child(
102                                Icon::new(provider.icon())
103                                    .size(IconSize::Small)
104                                    .color(Color::Muted),
105                            )
106                            .child(Label::new(provider_name.clone())),
107                    )
108                    .when(provider.is_authenticated(cx), |parent| {
109                        parent.child(
110                            Button::new(
111                                SharedString::from(format!("new-thread-{provider_id}")),
112                                "Start New Thread",
113                            )
114                            .icon_position(IconPosition::Start)
115                            .icon(IconName::Plus)
116                            .icon_size(IconSize::Small)
117                            .style(ButtonStyle::Filled)
118                            .layer(ElevationIndex::ModalSurface)
119                            .label_size(LabelSize::Small)
120                            .on_click(cx.listener({
121                                let provider = provider.clone();
122                                move |_this, _event, _window, cx| {
123                                    cx.emit(AssistantConfigurationEvent::NewThread(
124                                        provider.clone(),
125                                    ))
126                                }
127                            })),
128                        )
129                    }),
130            )
131            .child(
132                div()
133                    .p(DynamicSpacing::Base08.rems(cx))
134                    .bg(cx.theme().colors().editor_background)
135                    .border_1()
136                    .border_color(cx.theme().colors().border_variant)
137                    .rounded_sm()
138                    .map(|parent| match configuration_view {
139                        Some(configuration_view) => parent.child(configuration_view),
140                        None => parent.child(div().child(Label::new(format!(
141                            "No configuration view for {provider_name}",
142                        )))),
143                    }),
144            )
145    }
146}
147
148impl Render for AssistantConfiguration {
149    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
150        let providers = LanguageModelRegistry::read_global(cx).providers();
151
152        v_flex()
153            .id("assistant-configuration")
154            .track_focus(&self.focus_handle(cx))
155            .bg(cx.theme().colors().panel_background)
156            .size_full()
157            .overflow_y_scroll()
158            .child(
159                v_flex()
160                    .p(DynamicSpacing::Base16.rems(cx))
161                    .gap_2()
162                    .child(
163                        v_flex()
164                            .gap_0p5()
165                            .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
166                            .child(
167                                Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
168                                    .color(Color::Muted),
169                            ),
170                    )
171                    .child(
172                        Button::new("open-prompt-library", "Open Prompt Library")
173                            .style(ButtonStyle::Filled)
174                            .layer(ElevationIndex::ModalSurface)
175                            .full_width()
176                            .icon(IconName::Book)
177                            .icon_size(IconSize::Small)
178                            .icon_position(IconPosition::Start)
179                            .on_click(|_event, window, cx| {
180                                window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
181                            }),
182                    ),
183            )
184            .child(Divider::horizontal().color(DividerColor::Border))
185            .child(
186                v_flex()
187                    .p(DynamicSpacing::Base16.rems(cx))
188                    .mt_1()
189                    .gap_6()
190                    .flex_1()
191                    .child(
192                        v_flex()
193                            .gap_0p5()
194                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
195                            .child(
196                                Label::new("Add at least one provider to use AI-powered features.")
197                                    .color(Color::Muted),
198                            ),
199                    )
200                    .children(
201                        providers
202                            .into_iter()
203                            .map(|provider| self.render_provider_configuration(&provider, cx)),
204                    ),
205            )
206    }
207}