1use crate::{
2 provider::{
3 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
4 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
5 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
6 },
7 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderState,
9};
10use client::Client;
11use collections::BTreeMap;
12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
13use std::sync::Arc;
14use ui::Context;
15
16pub fn init(client: Arc<Client>, cx: &mut AppContext) {
17 let registry = cx.new_model(|cx| {
18 let mut registry = LanguageModelRegistry::default();
19 register_language_model_providers(&mut registry, client, cx);
20 registry
21 });
22 cx.set_global(GlobalLanguageModelRegistry(registry));
23}
24
25fn register_language_model_providers(
26 registry: &mut LanguageModelRegistry,
27 client: Arc<Client>,
28 cx: &mut ModelContext<LanguageModelRegistry>,
29) {
30 use feature_flags::FeatureFlagAppExt;
31
32 registry.register_provider(
33 AnthropicLanguageModelProvider::new(client.http_client(), cx),
34 cx,
35 );
36 registry.register_provider(
37 OpenAiLanguageModelProvider::new(client.http_client(), cx),
38 cx,
39 );
40 registry.register_provider(
41 OllamaLanguageModelProvider::new(client.http_client(), cx),
42 cx,
43 );
44 registry.register_provider(
45 GoogleLanguageModelProvider::new(client.http_client(), cx),
46 cx,
47 );
48 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
49
50 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
51 let client = client.clone();
52 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
53 if enabled {
54 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
55 } else {
56 registry.unregister_provider(
57 &LanguageModelProviderId::from(
58 crate::provider::cloud::PROVIDER_NAME.to_string(),
59 ),
60 cx,
61 );
62 }
63 });
64 })
65 .detach();
66}
67
68struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
69
70impl Global for GlobalLanguageModelRegistry {}
71
72#[derive(Default)]
73pub struct LanguageModelRegistry {
74 active_model: Option<ActiveModel>,
75 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
76}
77
78pub struct ActiveModel {
79 provider: Arc<dyn LanguageModelProvider>,
80 model: Option<Arc<dyn LanguageModel>>,
81}
82
83pub struct ActiveModelChanged;
84
85impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
86
87impl LanguageModelRegistry {
88 pub fn global(cx: &AppContext) -> Model<Self> {
89 cx.global::<GlobalLanguageModelRegistry>().0.clone()
90 }
91
92 pub fn read_global(cx: &AppContext) -> &Self {
93 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
94 }
95
96 #[cfg(any(test, feature = "test-support"))]
97 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
98 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
99 let registry = cx.new_model(|cx| {
100 let mut registry = Self::default();
101 registry.register_provider(fake_provider.clone(), cx);
102 let model = fake_provider.provided_models(cx)[0].clone();
103 registry.set_active_model(Some(model), cx);
104 registry
105 });
106 cx.set_global(GlobalLanguageModelRegistry(registry));
107 fake_provider
108 }
109
110 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
111 &mut self,
112 provider: T,
113 cx: &mut ModelContext<Self>,
114 ) {
115 let name = provider.id();
116
117 if let Some(subscription) = provider.subscribe(cx) {
118 subscription.detach();
119 }
120
121 self.providers.insert(name, Arc::new(provider));
122 cx.notify();
123 }
124
125 pub fn unregister_provider(
126 &mut self,
127 name: &LanguageModelProviderId,
128 cx: &mut ModelContext<Self>,
129 ) {
130 if self.providers.remove(name).is_some() {
131 cx.notify();
132 }
133 }
134
135 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
136 let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
137 let mut providers = Vec::with_capacity(self.providers.len());
138 if let Some(provider) = self.providers.get(&zed_provider_id) {
139 providers.push(provider.clone());
140 }
141 providers.extend(self.providers.values().filter_map(|p| {
142 if p.id() != zed_provider_id {
143 Some(p.clone())
144 } else {
145 None
146 }
147 }));
148 providers
149 }
150
151 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
152 self.providers
153 .values()
154 .flat_map(|provider| provider.provided_models(cx))
155 .collect()
156 }
157
158 pub fn provider(
159 &self,
160 name: &LanguageModelProviderId,
161 ) -> Option<Arc<dyn LanguageModelProvider>> {
162 self.providers.get(name).cloned()
163 }
164
165 pub fn select_active_model(
166 &mut self,
167 provider: &LanguageModelProviderId,
168 model_id: &LanguageModelId,
169 cx: &mut ModelContext<Self>,
170 ) {
171 let Some(provider) = self.provider(&provider) else {
172 return;
173 };
174
175 let models = provider.provided_models(cx);
176 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
177 self.set_active_model(Some(model), cx);
178 }
179 }
180
181 pub fn set_active_provider(
182 &mut self,
183 provider: Option<Arc<dyn LanguageModelProvider>>,
184 cx: &mut ModelContext<Self>,
185 ) {
186 self.active_model = provider.map(|provider| ActiveModel {
187 provider,
188 model: None,
189 });
190 cx.emit(ActiveModelChanged);
191 }
192
193 pub fn set_active_model(
194 &mut self,
195 model: Option<Arc<dyn LanguageModel>>,
196 cx: &mut ModelContext<Self>,
197 ) {
198 if let Some(model) = model {
199 let provider_id = model.provider_id();
200 if let Some(provider) = self.providers.get(&provider_id).cloned() {
201 self.active_model = Some(ActiveModel {
202 provider,
203 model: Some(model),
204 });
205 cx.emit(ActiveModelChanged);
206 } else {
207 log::warn!("Active model's provider not found in registry");
208 }
209 } else {
210 self.active_model = None;
211 cx.emit(ActiveModelChanged);
212 }
213 }
214
215 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
216 Some(self.active_model.as_ref()?.provider.clone())
217 }
218
219 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
220 self.active_model.as_ref()?.model.clone()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering a provider lists it; unregistering it by id removes it.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx)
        });

        let fake_id = crate::provider::fake::provider_id();
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(&fake_id, cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}