use crate::{
    provider::{
        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
        google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider,
        open_ai::OpenAiLanguageModelProvider,
    },
    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
use client::Client;
use collections::BTreeMap;
use gpui::{AppContext, Global, Model, ModelContext, Subscription};
use std::sync::Arc;
use ui::Context;
14
15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
16 let registry = cx.new_model(|cx| {
17 let mut registry = LanguageModelRegistry::default();
18 register_language_model_providers(&mut registry, client, cx);
19 registry
20 });
21 cx.set_global(GlobalLanguageModelRegistry(registry));
22}
23
24fn register_language_model_providers(
25 registry: &mut LanguageModelRegistry,
26 client: Arc<Client>,
27 cx: &mut ModelContext<LanguageModelRegistry>,
28) {
29 use feature_flags::FeatureFlagAppExt;
30
31 registry.register_provider(
32 AnthropicLanguageModelProvider::new(client.http_client(), cx),
33 cx,
34 );
35 registry.register_provider(
36 OpenAiLanguageModelProvider::new(client.http_client(), cx),
37 cx,
38 );
39 registry.register_provider(
40 OllamaLanguageModelProvider::new(client.http_client(), cx),
41 cx,
42 );
43 registry.register_provider(
44 GoogleLanguageModelProvider::new(client.http_client(), cx),
45 cx,
46 );
47
48 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
49 let client = client.clone();
50 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
51 if enabled {
52 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
53 } else {
54 registry.unregister_provider(
55 &LanguageModelProviderId::from(
56 crate::provider::cloud::PROVIDER_NAME.to_string(),
57 ),
58 cx,
59 );
60 }
61 });
62 })
63 .detach();
64}
65
/// Newtype holding the application-wide [`LanguageModelRegistry`] model handle so it can be
/// stored as a gpui global. It is installed by [`init`] and by `LanguageModelRegistry::test`,
/// and read back by `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

// Marks the wrapper as a gpui global so it can be passed to `set_global` / read via `global`.
impl Global for GlobalLanguageModelRegistry {}
69
70#[derive(Default)]
71pub struct LanguageModelRegistry {
72 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
73}
74
75impl LanguageModelRegistry {
76 pub fn global(cx: &AppContext) -> Model<Self> {
77 cx.global::<GlobalLanguageModelRegistry>().0.clone()
78 }
79
80 pub fn read_global(cx: &AppContext) -> &Self {
81 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
82 }
83
84 #[cfg(any(test, feature = "test-support"))]
85 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
86 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
87 let registry = cx.new_model(|cx| {
88 let mut registry = Self::default();
89 registry.register_provider(fake_provider.clone(), cx);
90 registry
91 });
92 cx.set_global(GlobalLanguageModelRegistry(registry));
93 fake_provider
94 }
95
96 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
97 &mut self,
98 provider: T,
99 cx: &mut ModelContext<Self>,
100 ) {
101 let name = provider.id();
102
103 if let Some(subscription) = provider.subscribe(cx) {
104 subscription.detach();
105 }
106
107 self.providers.insert(name, Arc::new(provider));
108 cx.notify();
109 }
110
111 pub fn unregister_provider(
112 &mut self,
113 name: &LanguageModelProviderId,
114 cx: &mut ModelContext<Self>,
115 ) {
116 if self.providers.remove(name).is_some() {
117 cx.notify();
118 }
119 }
120
121 pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
122 self.providers.values()
123 }
124
125 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
126 self.providers
127 .values()
128 .flat_map(|provider| provider.provided_models(cx))
129 .collect()
130 }
131
132 pub fn provider(
133 &self,
134 name: &LanguageModelProviderId,
135 ) -> Option<Arc<dyn LanguageModelProvider>> {
136 self.providers.get(name).cloned()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{self, FakeLanguageModelProvider};

    /// A registered provider is listed by `providers()`, and unregistering it by id makes
    /// the listing empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());
        let fake_id = fake::provider_id();

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx)
        });

        let registered: Vec<_> = registry.read(cx).providers().collect();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(&fake_id, cx));

        let remaining: Vec<_> = registry.read(cx).providers().collect();
        assert!(remaining.is_empty());
    }
}