1use crate::{
2 provider::{
3 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
4 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
5 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
6 },
7 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderState,
9};
10use client::{Client, UserStore};
11use collections::BTreeMap;
12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
13use std::sync::Arc;
14use ui::Context;
15
16pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
17 let registry = cx.new_model(|cx| {
18 let mut registry = LanguageModelRegistry::default();
19 register_language_model_providers(&mut registry, user_store, client, cx);
20 registry
21 });
22 cx.set_global(GlobalLanguageModelRegistry(registry));
23}
24
25fn register_language_model_providers(
26 registry: &mut LanguageModelRegistry,
27 user_store: Model<UserStore>,
28 client: Arc<Client>,
29 cx: &mut ModelContext<LanguageModelRegistry>,
30) {
31 use feature_flags::FeatureFlagAppExt;
32
33 registry.register_provider(
34 AnthropicLanguageModelProvider::new(client.http_client(), cx),
35 cx,
36 );
37 registry.register_provider(
38 OpenAiLanguageModelProvider::new(client.http_client(), cx),
39 cx,
40 );
41 registry.register_provider(
42 OllamaLanguageModelProvider::new(client.http_client(), cx),
43 cx,
44 );
45 registry.register_provider(
46 GoogleLanguageModelProvider::new(client.http_client(), cx),
47 cx,
48 );
49 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
50
51 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
52 let user_store = user_store.clone();
53 let client = client.clone();
54 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
55 if enabled {
56 registry.register_provider(
57 CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
58 cx,
59 );
60 } else {
61 registry.unregister_provider(
62 LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
63 cx,
64 );
65 }
66 });
67 })
68 .detach();
69}
70
71struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
72
73impl Global for GlobalLanguageModelRegistry {}
74
75#[derive(Default)]
76pub struct LanguageModelRegistry {
77 active_model: Option<ActiveModel>,
78 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
79}
80
81pub struct ActiveModel {
82 provider: Arc<dyn LanguageModelProvider>,
83 model: Option<Arc<dyn LanguageModel>>,
84}
85
86pub enum Event {
87 ActiveModelChanged,
88 ProviderStateChanged,
89 AddedProvider(LanguageModelProviderId),
90 RemovedProvider(LanguageModelProviderId),
91}
92
93impl EventEmitter<Event> for LanguageModelRegistry {}
94
95impl LanguageModelRegistry {
96 pub fn global(cx: &AppContext) -> Model<Self> {
97 cx.global::<GlobalLanguageModelRegistry>().0.clone()
98 }
99
100 pub fn read_global(cx: &AppContext) -> &Self {
101 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
102 }
103
104 #[cfg(any(test, feature = "test-support"))]
105 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
106 let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
107 let registry = cx.new_model(|cx| {
108 let mut registry = Self::default();
109 registry.register_provider(fake_provider.clone(), cx);
110 let model = fake_provider.provided_models(cx)[0].clone();
111 registry.set_active_model(Some(model), cx);
112 registry
113 });
114 cx.set_global(GlobalLanguageModelRegistry(registry));
115 fake_provider
116 }
117
118 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
119 &mut self,
120 provider: T,
121 cx: &mut ModelContext<Self>,
122 ) {
123 let id = provider.id();
124
125 let subscription = provider.subscribe(cx, |_, cx| {
126 cx.emit(Event::ProviderStateChanged);
127 });
128 if let Some(subscription) = subscription {
129 subscription.detach();
130 }
131
132 self.providers.insert(id.clone(), Arc::new(provider));
133 cx.emit(Event::AddedProvider(id));
134 }
135
136 pub fn unregister_provider(
137 &mut self,
138 id: LanguageModelProviderId,
139 cx: &mut ModelContext<Self>,
140 ) {
141 if self.providers.remove(&id).is_some() {
142 cx.emit(Event::RemovedProvider(id));
143 }
144 }
145
146 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
147 let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
148 let mut providers = Vec::with_capacity(self.providers.len());
149 if let Some(provider) = self.providers.get(&zed_provider_id) {
150 providers.push(provider.clone());
151 }
152 providers.extend(self.providers.values().filter_map(|p| {
153 if p.id() != zed_provider_id {
154 Some(p.clone())
155 } else {
156 None
157 }
158 }));
159 providers
160 }
161
162 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
163 self.providers
164 .values()
165 .flat_map(|provider| provider.provided_models(cx))
166 .collect()
167 }
168
169 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
170 self.providers.get(id).cloned()
171 }
172
173 pub fn select_active_model(
174 &mut self,
175 provider: &LanguageModelProviderId,
176 model_id: &LanguageModelId,
177 cx: &mut ModelContext<Self>,
178 ) {
179 let Some(provider) = self.provider(&provider) else {
180 return;
181 };
182
183 let models = provider.provided_models(cx);
184 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
185 self.set_active_model(Some(model), cx);
186 }
187 }
188
189 pub fn set_active_provider(
190 &mut self,
191 provider: Option<Arc<dyn LanguageModelProvider>>,
192 cx: &mut ModelContext<Self>,
193 ) {
194 self.active_model = provider.map(|provider| ActiveModel {
195 provider,
196 model: None,
197 });
198 cx.emit(Event::ActiveModelChanged);
199 }
200
201 pub fn set_active_model(
202 &mut self,
203 model: Option<Arc<dyn LanguageModel>>,
204 cx: &mut ModelContext<Self>,
205 ) {
206 if let Some(model) = model {
207 let provider_id = model.provider_id();
208 if let Some(provider) = self.providers.get(&provider_id).cloned() {
209 self.active_model = Some(ActiveModel {
210 provider,
211 model: Some(model),
212 });
213 cx.emit(Event::ActiveModelChanged);
214 } else {
215 log::warn!("Active model's provider not found in registry");
216 }
217 } else {
218 self.active_model = None;
219 cx.emit(Event::ActiveModelChanged);
220 }
221 }
222
223 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
224 Some(self.active_model.as_ref()?.provider.clone())
225 }
226
227 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
228 self.active_model.as_ref()?.model.clone()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering and then unregistering the fake provider must be reflected
    /// by `LanguageModelRegistry::providers`.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        // After registration exactly one provider, the fake one, is listed.
        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });
        let after_register = registry.read(cx).providers();
        assert_eq!(after_register.len(), 1);
        assert_eq!(
            after_register[0].id(),
            crate::provider::fake::provider_id()
        );

        // After unregistration the listing is empty again.
        registry.update(cx, |this, cx| {
            this.unregister_provider(crate::provider::fake::provider_id(), cx);
        });
        let after_unregister = registry.read(cx).providers();
        assert!(after_unregister.is_empty());
    }
}