assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use gpui::{AnyView, AppContext, EventEmitter, FocusHandle, FocusableView, Subscription};
  5use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  6use ui::{prelude::*, ElevationIndex};
  7
/// View state for the assistant configuration UI: one configuration view per
/// language model provider, plus the focus handle used while rendering.
pub struct AssistantConfiguration {
    /// Focus handle created in `new` and handed out by `FocusableView::focus_handle`.
    focus_handle: FocusHandle,
    /// Configuration view obtained from each provider, keyed by the provider's id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Subscription to `LanguageModelRegistry` events. Never read directly;
    /// presumably stored so the subscription stays alive as long as this value
    /// does (confirm against gpui's `Subscription` semantics).
    _registry_subscription: Subscription,
}
 13
 14impl AssistantConfiguration {
 15    pub fn new(cx: &mut ViewContext<Self>) -> Self {
 16        let focus_handle = cx.focus_handle();
 17
 18        let registry_subscription = cx.subscribe(
 19            &LanguageModelRegistry::global(cx),
 20            |this, _, event: &language_model::Event, cx| match event {
 21                language_model::Event::AddedProvider(provider_id) => {
 22                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 23                    if let Some(provider) = provider {
 24                        this.add_provider_configuration_view(&provider, cx);
 25                    }
 26                }
 27                language_model::Event::RemovedProvider(provider_id) => {
 28                    this.remove_provider_configuration_view(provider_id);
 29                }
 30                _ => {}
 31            },
 32        );
 33
 34        let mut this = Self {
 35            focus_handle,
 36            configuration_views_by_provider: HashMap::default(),
 37            _registry_subscription: registry_subscription,
 38        };
 39        this.build_provider_configuration_views(cx);
 40        this
 41    }
 42
 43    fn build_provider_configuration_views(&mut self, cx: &mut ViewContext<Self>) {
 44        let providers = LanguageModelRegistry::read_global(cx).providers();
 45        for provider in providers {
 46            self.add_provider_configuration_view(&provider, cx);
 47        }
 48    }
 49
 50    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 51        self.configuration_views_by_provider.remove(provider_id);
 52    }
 53
 54    fn add_provider_configuration_view(
 55        &mut self,
 56        provider: &Arc<dyn LanguageModelProvider>,
 57        cx: &mut ViewContext<Self>,
 58    ) {
 59        let configuration_view = provider.configuration_view(cx);
 60        self.configuration_views_by_provider
 61            .insert(provider.id(), configuration_view);
 62    }
 63}
 64
 65impl FocusableView for AssistantConfiguration {
 66    fn focus_handle(&self, _: &AppContext) -> FocusHandle {
 67        self.focus_handle.clone()
 68    }
 69}
 70
/// Events emitted by `AssistantConfiguration`.
pub enum AssistantConfigurationEvent {
    /// Emitted when a provider's "Open New Thread" button is clicked; carries
    /// the provider whose button was clicked.
    NewThread(Arc<dyn LanguageModelProvider>),
}
 74
// Declares `AssistantConfiguration` as an emitter of `AssistantConfigurationEvent`;
// the event itself is emitted from the "Open New Thread" click handler via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 76
 77impl AssistantConfiguration {
 78    fn render_provider_configuration(
 79        &mut self,
 80        provider: &Arc<dyn LanguageModelProvider>,
 81        cx: &mut ViewContext<Self>,
 82    ) -> impl IntoElement {
 83        let provider_id = provider.id().0.clone();
 84        let provider_name = provider.name().0.clone();
 85        let configuration_view = self
 86            .configuration_views_by_provider
 87            .get(&provider.id())
 88            .cloned();
 89
 90        v_flex()
 91            .gap_2()
 92            .child(
 93                h_flex()
 94                    .justify_between()
 95                    .child(Headline::new(provider_name.clone()).size(HeadlineSize::Small))
 96                    .when(provider.is_authenticated(cx), |parent| {
 97                        parent.child(
 98                            h_flex().justify_end().child(
 99                                Button::new(
100                                    SharedString::from(format!("new-thread-{provider_id}")),
101                                    "Open New Thread",
102                                )
103                                .icon_position(IconPosition::Start)
104                                .icon(IconName::Plus)
105                                .style(ButtonStyle::Filled)
106                                .layer(ElevationIndex::ModalSurface)
107                                .on_click(cx.listener({
108                                    let provider = provider.clone();
109                                    move |_this, _event, cx| {
110                                        cx.emit(AssistantConfigurationEvent::NewThread(
111                                            provider.clone(),
112                                        ))
113                                    }
114                                })),
115                            ),
116                        )
117                    }),
118            )
119            .child(
120                div()
121                    .p(DynamicSpacing::Base08.rems(cx))
122                    .bg(cx.theme().colors().surface_background)
123                    .border_1()
124                    .border_color(cx.theme().colors().border_variant)
125                    .rounded_md()
126                    .map(|parent| match configuration_view {
127                        Some(configuration_view) => parent.child(configuration_view),
128                        None => parent.child(div().child(Label::new(format!(
129                            "No configuration view for {provider_name}",
130                        )))),
131                    }),
132            )
133    }
134}
135
136impl Render for AssistantConfiguration {
137    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
138        let providers = LanguageModelRegistry::read_global(cx).providers();
139
140        v_flex()
141            .id("assistant-configuration")
142            .track_focus(&self.focus_handle(cx))
143            .bg(cx.theme().colors().editor_background)
144            .size_full()
145            .overflow_y_scroll()
146            .child(
147                v_flex()
148                    .p(DynamicSpacing::Base16.rems(cx))
149                    .mt_1()
150                    .gap_6()
151                    .flex_1()
152                    .children(
153                        providers
154                            .into_iter()
155                            .map(|provider| self.render_provider_configuration(&provider, cx)),
156                    ),
157            )
158    }
159}