1use client::Client;
2use collections::BTreeMap;
3use gpui::{AppContext, Global, Model, ModelContext};
4use std::sync::Arc;
5use ui::Context;
6
7use crate::{
8 provider::{
9 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
10 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
11 },
12 LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
13};
14
15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
16 let registry = cx.new_model(|cx| {
17 let mut registry = LanguageModelRegistry::default();
18 register_language_model_providers(&mut registry, client, cx);
19 registry
20 });
21 cx.set_global(GlobalLanguageModelRegistry(registry));
22}
23
24fn register_language_model_providers(
25 registry: &mut LanguageModelRegistry,
26 client: Arc<Client>,
27 cx: &mut ModelContext<LanguageModelRegistry>,
28) {
29 use feature_flags::FeatureFlagAppExt;
30
31 registry.register_provider(
32 AnthropicLanguageModelProvider::new(client.http_client(), cx),
33 cx,
34 );
35 registry.register_provider(
36 OpenAiLanguageModelProvider::new(client.http_client(), cx),
37 cx,
38 );
39 registry.register_provider(
40 OllamaLanguageModelProvider::new(client.http_client(), cx),
41 cx,
42 );
43
44 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
45 let client = client.clone();
46 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
47 if enabled {
48 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
49 } else {
50 registry.unregister_provider(
51 &LanguageModelProviderId::from(
52 crate::provider::cloud::PROVIDER_NAME.to_string(),
53 ),
54 cx,
55 );
56 }
57 });
58 })
59 .detach();
60}
61
/// Wrapper that stores the shared registry model as a gpui global.
///
/// Installed by `init` (and by `LanguageModelRegistry::test`); read back through
/// `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
65
/// Collection of the language model providers currently available, keyed by provider id.
///
/// A `BTreeMap` is used, so iteration (see `providers` / `available_models`) follows the
/// ordering of the provider ids.
#[derive(Default)]
pub struct LanguageModelRegistry {
    // Registering a provider whose id is already present replaces the previous entry.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
70
71impl LanguageModelRegistry {
72 pub fn global(cx: &AppContext) -> Model<Self> {
73 cx.global::<GlobalLanguageModelRegistry>().0.clone()
74 }
75
76 pub fn read_global(cx: &AppContext) -> &Self {
77 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
78 }
79
80 #[cfg(any(test, feature = "test-support"))]
81 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
82 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
83 let registry = cx.new_model(|cx| {
84 let mut registry = Self::default();
85 registry.register_provider(fake_provider.clone(), cx);
86 registry
87 });
88 cx.set_global(GlobalLanguageModelRegistry(registry));
89 fake_provider
90 }
91
92 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
93 &mut self,
94 provider: T,
95 cx: &mut ModelContext<Self>,
96 ) {
97 let name = provider.id();
98
99 if let Some(subscription) = provider.subscribe(cx) {
100 subscription.detach();
101 }
102
103 self.providers.insert(name, Arc::new(provider));
104 cx.notify();
105 }
106
107 pub fn unregister_provider(
108 &mut self,
109 name: &LanguageModelProviderId,
110 cx: &mut ModelContext<Self>,
111 ) {
112 if self.providers.remove(name).is_some() {
113 cx.notify();
114 }
115 }
116
117 pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
118 self.providers.values()
119 }
120
121 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
122 self.providers
123 .values()
124 .flat_map(|provider| provider.provided_models(cx))
125 .collect()
126 }
127
128 pub fn provider(
129 &self,
130 name: &LanguageModelProviderId,
131 ) -> Option<Arc<dyn LanguageModelProvider>> {
132 self.providers.get(name).cloned()
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id, FakeLanguageModelProvider};

    /// Registering then unregistering the fake provider leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider::default(), cx);
        });

        let registered_ids: Vec<_> = registry
            .read(cx)
            .providers()
            .map(|provider| provider.id())
            .collect();
        assert_eq!(registered_ids, vec![provider_id()]);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(&provider_id(), cx);
        });

        assert_eq!(registry.read(cx).providers().count(), 0);
    }
}