1use client::Client;
2use collections::HashMap;
3use gpui::{AppContext, Global, Model, ModelContext};
4use std::sync::Arc;
5use ui::Context;
6
7use crate::{
8 provider::{
9 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
10 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
11 },
12 LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
13};
14
15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
16 let registry = cx.new_model(|cx| {
17 let mut registry = LanguageModelRegistry::default();
18 register_language_model_providers(&mut registry, client, cx);
19 registry
20 });
21 cx.set_global(GlobalLanguageModelRegistry(registry));
22}
23
24fn register_language_model_providers(
25 registry: &mut LanguageModelRegistry,
26 client: Arc<Client>,
27 cx: &mut ModelContext<LanguageModelRegistry>,
28) {
29 use feature_flags::FeatureFlagAppExt;
30
31 registry.register_provider(
32 AnthropicLanguageModelProvider::new(client.http_client(), cx),
33 cx,
34 );
35 registry.register_provider(
36 OpenAiLanguageModelProvider::new(client.http_client(), cx),
37 cx,
38 );
39 registry.register_provider(
40 OllamaLanguageModelProvider::new(client.http_client(), cx),
41 cx,
42 );
43
44 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
45 let client = client.clone();
46 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
47 if enabled {
48 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
49 } else {
50 registry.unregister_provider(
51 &LanguageModelProviderName::from(
52 crate::provider::cloud::PROVIDER_NAME.to_string(),
53 ),
54 cx,
55 );
56 }
57 });
58 })
59 .detach();
60}
61
/// gpui global wrapper around the handle to the application-wide
/// [`LanguageModelRegistry`] model. It is installed with `cx.set_global` by
/// `init` (and by `LanguageModelRegistry::test`) and read back by
/// `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

// Marker impl that allows the wrapper to be stored as a gpui global.
impl Global for GlobalLanguageModelRegistry {}
65
66#[derive(Default)]
67pub struct LanguageModelRegistry {
68 providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
69}
70
71impl LanguageModelRegistry {
72 pub fn global(cx: &AppContext) -> Model<Self> {
73 cx.global::<GlobalLanguageModelRegistry>().0.clone()
74 }
75
76 pub fn read_global(cx: &AppContext) -> &Self {
77 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
78 }
79
80 #[cfg(any(test, feature = "test-support"))]
81 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
82 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
83 let registry = cx.new_model(|cx| {
84 let mut registry = Self::default();
85 registry.register_provider(fake_provider.clone(), cx);
86 registry
87 });
88 cx.set_global(GlobalLanguageModelRegistry(registry));
89 fake_provider
90 }
91
92 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
93 &mut self,
94 provider: T,
95 cx: &mut ModelContext<Self>,
96 ) {
97 let name = provider.name();
98
99 if let Some(subscription) = provider.subscribe(cx) {
100 subscription.detach();
101 }
102
103 self.providers.insert(name, Arc::new(provider));
104 cx.notify();
105 }
106
107 pub fn unregister_provider(
108 &mut self,
109 name: &LanguageModelProviderName,
110 cx: &mut ModelContext<Self>,
111 ) {
112 if self.providers.remove(name).is_some() {
113 cx.notify();
114 }
115 }
116
117 pub fn providers(
118 &self,
119 ) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
120 self.providers.iter()
121 }
122
123 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
124 self.providers
125 .values()
126 .flat_map(|provider| provider.provided_models(cx))
127 .collect()
128 }
129
130 pub fn available_models_grouped_by_provider(
131 &self,
132 cx: &AppContext,
133 ) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
134 self.providers
135 .iter()
136 .map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
137 .collect()
138 }
139
140 pub fn provider(
141 &self,
142 name: &LanguageModelProviderName,
143 ) -> Option<Arc<dyn LanguageModelProvider>> {
144 self.providers.get(name).cloned()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`, and
    /// unregistering it by name removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());
        let expected_name = crate::provider::fake::provider_name();

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider::default(), cx);
        });

        let registered_names = registry
            .read(cx)
            .providers()
            .map(|(name, _)| name.clone())
            .collect::<Vec<_>>();
        assert_eq!(registered_names, vec![expected_name.clone()]);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(&expected_name, cx);
        });

        assert_eq!(registry.read(cx).providers().count(), 0);
    }
}