1use crate::{
2 provider::{
3 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
4 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
5 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
6 },
7 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderState,
9};
10use client::Client;
11use collections::BTreeMap;
12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
13use std::sync::Arc;
14use ui::Context;
15
16pub fn init(client: Arc<Client>, cx: &mut AppContext) {
17 let registry = cx.new_model(|cx| {
18 let mut registry = LanguageModelRegistry::default();
19 register_language_model_providers(&mut registry, client, cx);
20 registry
21 });
22 cx.set_global(GlobalLanguageModelRegistry(registry));
23}
24
25fn register_language_model_providers(
26 registry: &mut LanguageModelRegistry,
27 client: Arc<Client>,
28 cx: &mut ModelContext<LanguageModelRegistry>,
29) {
30 use feature_flags::FeatureFlagAppExt;
31
32 registry.register_provider(
33 AnthropicLanguageModelProvider::new(client.http_client(), cx),
34 cx,
35 );
36 registry.register_provider(
37 OpenAiLanguageModelProvider::new(client.http_client(), cx),
38 cx,
39 );
40 registry.register_provider(
41 OllamaLanguageModelProvider::new(client.http_client(), cx),
42 cx,
43 );
44 registry.register_provider(
45 GoogleLanguageModelProvider::new(client.http_client(), cx),
46 cx,
47 );
48 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
49
50 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
51 let client = client.clone();
52 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
53 if enabled {
54 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
55 } else {
56 registry.unregister_provider(
57 LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
58 cx,
59 );
60 }
61 });
62 })
63 .detach();
64}
65
/// Newtype around the registry's model handle so it can be stored as a gpui
/// global. It is installed by `init` (and by `LanguageModelRegistry::test`)
/// and read back by `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
69
/// Tracks the registered language model providers and the currently active
/// provider/model selection.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Current selection; `None` until a provider or model has been made active.
    active_model: Option<ActiveModel>,
    /// Registered providers, keyed by their provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
75
/// The active selection held by `LanguageModelRegistry`: a provider plus,
/// optionally, one of its models.
pub struct ActiveModel {
    /// Provider that owns the active selection.
    provider: Arc<dyn LanguageModelProvider>,
    /// Selected model; `None` when only a provider was chosen
    /// (see `LanguageModelRegistry::set_active_provider`).
    model: Option<Arc<dyn LanguageModel>>,
}
80
/// Events emitted by `LanguageModelRegistry`.
pub enum Event {
    /// Emitted by `set_active_provider` and `set_active_model` when the
    /// active selection is replaced or cleared.
    ActiveModelChanged,
    /// Emitted when a registered provider's state subscription fires.
    ProviderStateChanged,
    /// Emitted by `register_provider` with the id of the provider that was inserted.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider with this id was actually removed.
    RemovedProvider(LanguageModelProviderId),
}

// Lets the registry emit `Event` values through its model context.
impl EventEmitter<Event> for LanguageModelRegistry {}
89
90impl LanguageModelRegistry {
91 pub fn global(cx: &AppContext) -> Model<Self> {
92 cx.global::<GlobalLanguageModelRegistry>().0.clone()
93 }
94
95 pub fn read_global(cx: &AppContext) -> &Self {
96 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
97 }
98
99 #[cfg(any(test, feature = "test-support"))]
100 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
101 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
102 let registry = cx.new_model(|cx| {
103 let mut registry = Self::default();
104 registry.register_provider(fake_provider.clone(), cx);
105 let model = fake_provider.provided_models(cx)[0].clone();
106 registry.set_active_model(Some(model), cx);
107 registry
108 });
109 cx.set_global(GlobalLanguageModelRegistry(registry));
110 fake_provider
111 }
112
113 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
114 &mut self,
115 provider: T,
116 cx: &mut ModelContext<Self>,
117 ) {
118 let id = provider.id();
119
120 let subscription = provider.subscribe(cx, |_, cx| {
121 cx.emit(Event::ProviderStateChanged);
122 });
123 if let Some(subscription) = subscription {
124 subscription.detach();
125 }
126
127 self.providers.insert(id.clone(), Arc::new(provider));
128 cx.emit(Event::AddedProvider(id));
129 }
130
131 pub fn unregister_provider(
132 &mut self,
133 id: LanguageModelProviderId,
134 cx: &mut ModelContext<Self>,
135 ) {
136 if self.providers.remove(&id).is_some() {
137 cx.emit(Event::RemovedProvider(id));
138 }
139 }
140
141 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
142 let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
143 let mut providers = Vec::with_capacity(self.providers.len());
144 if let Some(provider) = self.providers.get(&zed_provider_id) {
145 providers.push(provider.clone());
146 }
147 providers.extend(self.providers.values().filter_map(|p| {
148 if p.id() != zed_provider_id {
149 Some(p.clone())
150 } else {
151 None
152 }
153 }));
154 providers
155 }
156
157 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
158 self.providers
159 .values()
160 .flat_map(|provider| provider.provided_models(cx))
161 .collect()
162 }
163
164 pub fn provider(
165 &self,
166 name: &LanguageModelProviderId,
167 ) -> Option<Arc<dyn LanguageModelProvider>> {
168 self.providers.get(name).cloned()
169 }
170
171 pub fn select_active_model(
172 &mut self,
173 provider: &LanguageModelProviderId,
174 model_id: &LanguageModelId,
175 cx: &mut ModelContext<Self>,
176 ) {
177 let Some(provider) = self.provider(&provider) else {
178 return;
179 };
180
181 let models = provider.provided_models(cx);
182 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
183 self.set_active_model(Some(model), cx);
184 }
185 }
186
187 pub fn set_active_provider(
188 &mut self,
189 provider: Option<Arc<dyn LanguageModelProvider>>,
190 cx: &mut ModelContext<Self>,
191 ) {
192 self.active_model = provider.map(|provider| ActiveModel {
193 provider,
194 model: None,
195 });
196 cx.emit(Event::ActiveModelChanged);
197 }
198
199 pub fn set_active_model(
200 &mut self,
201 model: Option<Arc<dyn LanguageModel>>,
202 cx: &mut ModelContext<Self>,
203 ) {
204 if let Some(model) = model {
205 let provider_id = model.provider_id();
206 if let Some(provider) = self.providers.get(&provider_id).cloned() {
207 self.active_model = Some(ActiveModel {
208 provider,
209 model: Some(model),
210 });
211 cx.emit(Event::ActiveModelChanged);
212 } else {
213 log::warn!("Active model's provider not found in registry");
214 }
215 } else {
216 self.active_model = None;
217 cx.emit(Event::ActiveModelChanged);
218 }
219 }
220
221 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
222 Some(self.active_model.as_ref()?.provider.clone())
223 }
224
225 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
226 self.active_model.as_ref()?.model.clone()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    // A registered provider shows up in `providers()`, and unregistering it
    // by id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let handle = cx.new_model(|_| LanguageModelRegistry::default());

        // Register the fake provider; it must be the only entry.
        handle.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx);
        });
        let listed = handle.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), crate::provider::fake::provider_id());

        // Unregister it by id; the registry must be empty afterwards.
        handle.update(cx, |this, cx| {
            this.unregister_provider(crate::provider::fake::provider_id(), cx);
        });
        let listed = handle.read(cx).providers();
        assert!(listed.is_empty());
    }
}