assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use gpui::{Action, AnyView, AppContext, EventEmitter, FocusHandle, FocusableView, Subscription};
  5use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  6use ui::{prelude::*, ElevationIndex};
  7use zed_actions::assistant::DeployPromptLibrary;
  8
  9pub struct AssistantConfiguration {
 10    focus_handle: FocusHandle,
 11    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 12    _registry_subscription: Subscription,
 13}
 14
 15impl AssistantConfiguration {
 16    pub fn new(cx: &mut ViewContext<Self>) -> Self {
 17        let focus_handle = cx.focus_handle();
 18
 19        let registry_subscription = cx.subscribe(
 20            &LanguageModelRegistry::global(cx),
 21            |this, _, event: &language_model::Event, cx| match event {
 22                language_model::Event::AddedProvider(provider_id) => {
 23                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 24                    if let Some(provider) = provider {
 25                        this.add_provider_configuration_view(&provider, cx);
 26                    }
 27                }
 28                language_model::Event::RemovedProvider(provider_id) => {
 29                    this.remove_provider_configuration_view(provider_id);
 30                }
 31                _ => {}
 32            },
 33        );
 34
 35        let mut this = Self {
 36            focus_handle,
 37            configuration_views_by_provider: HashMap::default(),
 38            _registry_subscription: registry_subscription,
 39        };
 40        this.build_provider_configuration_views(cx);
 41        this
 42    }
 43
 44    fn build_provider_configuration_views(&mut self, cx: &mut ViewContext<Self>) {
 45        let providers = LanguageModelRegistry::read_global(cx).providers();
 46        for provider in providers {
 47            self.add_provider_configuration_view(&provider, cx);
 48        }
 49    }
 50
 51    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 52        self.configuration_views_by_provider.remove(provider_id);
 53    }
 54
 55    fn add_provider_configuration_view(
 56        &mut self,
 57        provider: &Arc<dyn LanguageModelProvider>,
 58        cx: &mut ViewContext<Self>,
 59    ) {
 60        let configuration_view = provider.configuration_view(cx);
 61        self.configuration_views_by_provider
 62            .insert(provider.id(), configuration_view);
 63    }
 64}
 65
 66impl FocusableView for AssistantConfiguration {
 67    fn focus_handle(&self, _: &AppContext) -> FocusHandle {
 68        self.focus_handle.clone()
 69    }
 70}
 71
 72pub enum AssistantConfigurationEvent {
 73    NewThread(Arc<dyn LanguageModelProvider>),
 74}
 75
 76impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 77
 78impl AssistantConfiguration {
 79    fn render_provider_configuration(
 80        &mut self,
 81        provider: &Arc<dyn LanguageModelProvider>,
 82        cx: &mut ViewContext<Self>,
 83    ) -> impl IntoElement {
 84        let provider_id = provider.id().0.clone();
 85        let provider_name = provider.name().0.clone();
 86        let configuration_view = self
 87            .configuration_views_by_provider
 88            .get(&provider.id())
 89            .cloned();
 90
 91        v_flex()
 92            .gap_2()
 93            .child(
 94                h_flex()
 95                    .justify_between()
 96                    .child(Headline::new(provider_name.clone()).size(HeadlineSize::Small))
 97                    .when(provider.is_authenticated(cx), |parent| {
 98                        parent.child(
 99                            h_flex().justify_end().child(
100                                Button::new(
101                                    SharedString::from(format!("new-thread-{provider_id}")),
102                                    "Open New Thread",
103                                )
104                                .icon_position(IconPosition::Start)
105                                .icon(IconName::Plus)
106                                .style(ButtonStyle::Filled)
107                                .layer(ElevationIndex::ModalSurface)
108                                .on_click(cx.listener({
109                                    let provider = provider.clone();
110                                    move |_this, _event, cx| {
111                                        cx.emit(AssistantConfigurationEvent::NewThread(
112                                            provider.clone(),
113                                        ))
114                                    }
115                                })),
116                            ),
117                        )
118                    }),
119            )
120            .child(
121                div()
122                    .p(DynamicSpacing::Base08.rems(cx))
123                    .bg(cx.theme().colors().surface_background)
124                    .border_1()
125                    .border_color(cx.theme().colors().border_variant)
126                    .rounded_md()
127                    .map(|parent| match configuration_view {
128                        Some(configuration_view) => parent.child(configuration_view),
129                        None => parent.child(div().child(Label::new(format!(
130                            "No configuration view for {provider_name}",
131                        )))),
132                    }),
133            )
134    }
135}
136
137impl Render for AssistantConfiguration {
138    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
139        let providers = LanguageModelRegistry::read_global(cx).providers();
140
141        v_flex()
142            .id("assistant-configuration")
143            .track_focus(&self.focus_handle(cx))
144            .bg(cx.theme().colors().editor_background)
145            .size_full()
146            .overflow_y_scroll()
147            .child(
148                h_flex().p(DynamicSpacing::Base16.rems(cx)).child(
149                    Button::new("open-prompt-library", "Open Prompt Library")
150                        .style(ButtonStyle::Filled)
151                        .full_width()
152                        .icon(IconName::Book)
153                        .icon_size(IconSize::Small)
154                        .icon_position(IconPosition::Start)
155                        .on_click(|_event, cx| {
156                            cx.dispatch_action(DeployPromptLibrary.boxed_clone())
157                        }),
158                ),
159            )
160            .child(
161                v_flex()
162                    .p(DynamicSpacing::Base16.rems(cx))
163                    .mt_1()
164                    .gap_6()
165                    .flex_1()
166                    .children(
167                        providers
168                            .into_iter()
169                            .map(|provider| self.render_provider_configuration(&provider, cx)),
170                    ),
171            )
172    }
173}