1use crate::{
2 provider::{
3 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
4 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
5 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
6 },
7 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderState,
9};
10use client::{Client, UserStore};
11use collections::BTreeMap;
12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
13use std::sync::Arc;
14use ui::Context;
15
16pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
17 let registry = cx.new_model(|cx| {
18 let mut registry = LanguageModelRegistry::default();
19 register_language_model_providers(&mut registry, user_store, client, cx);
20 registry
21 });
22 cx.set_global(GlobalLanguageModelRegistry(registry));
23}
24
25fn register_language_model_providers(
26 registry: &mut LanguageModelRegistry,
27 user_store: Model<UserStore>,
28 client: Arc<Client>,
29 cx: &mut ModelContext<LanguageModelRegistry>,
30) {
31 use feature_flags::FeatureFlagAppExt;
32
33 registry.register_provider(
34 AnthropicLanguageModelProvider::new(client.http_client(), cx),
35 cx,
36 );
37 registry.register_provider(
38 OpenAiLanguageModelProvider::new(client.http_client(), cx),
39 cx,
40 );
41 registry.register_provider(
42 OllamaLanguageModelProvider::new(client.http_client(), cx),
43 cx,
44 );
45 registry.register_provider(
46 GoogleLanguageModelProvider::new(client.http_client(), cx),
47 cx,
48 );
49 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
50
51 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
52 let user_store = user_store.clone();
53 let client = client.clone();
54 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
55 if enabled {
56 registry.register_provider(
57 CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
58 cx,
59 );
60 } else {
61 registry.unregister_provider(
62 LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
63 cx,
64 );
65 }
66 });
67 })
68 .detach();
69}
70
/// Newtype wrapper around the registry's model handle so it can be stored as
/// a gpui global (set in `init` and read back by
/// `LanguageModelRegistry::global` / `read_global`).
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

// Marker impl that lets the wrapper be used with `set_global` / `global`.
impl Global for GlobalLanguageModelRegistry {}
74
/// Keeps track of the registered language model providers, the currently
/// active provider/model selection, and the models chosen as inline
/// alternatives.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The active provider together with the model selected from it, if any.
    /// `None` when nothing is active.
    active_model: Option<ActiveModel>,
    /// Registered providers, keyed by their provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models picked by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
81
/// The registry's current selection: a provider plus, optionally, one of its
/// models.
pub struct ActiveModel {
    /// The provider that is currently active.
    provider: Arc<dyn LanguageModelProvider>,
    /// The selected model; `None` when only a provider has been chosen
    /// (as done by `set_active_provider`).
    model: Option<Arc<dyn LanguageModel>>,
}
86
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_active_provider` and `set_active_model` when the
    /// active selection is written.
    ActiveModelChanged,
    /// Emitted from the subscription that `register_provider` sets up on a
    /// provider.
    ProviderStateChanged,
    /// Emitted by `register_provider`; carries the id of the added provider.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider was actually removed;
    /// carries the id of the removed provider.
    RemovedProvider(LanguageModelProviderId),
}
93
// Lets the registry emit `Event` values through `cx.emit`.
impl EventEmitter<Event> for LanguageModelRegistry {}
95
96impl LanguageModelRegistry {
97 pub fn global(cx: &AppContext) -> Model<Self> {
98 cx.global::<GlobalLanguageModelRegistry>().0.clone()
99 }
100
101 pub fn read_global(cx: &AppContext) -> &Self {
102 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
103 }
104
105 #[cfg(any(test, feature = "test-support"))]
106 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
107 let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
108 let registry = cx.new_model(|cx| {
109 let mut registry = Self::default();
110 registry.register_provider(fake_provider.clone(), cx);
111 let model = fake_provider.provided_models(cx)[0].clone();
112 registry.set_active_model(Some(model), cx);
113 registry
114 });
115 cx.set_global(GlobalLanguageModelRegistry(registry));
116 fake_provider
117 }
118
119 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
120 &mut self,
121 provider: T,
122 cx: &mut ModelContext<Self>,
123 ) {
124 let id = provider.id();
125
126 let subscription = provider.subscribe(cx, |_, cx| {
127 cx.emit(Event::ProviderStateChanged);
128 });
129 if let Some(subscription) = subscription {
130 subscription.detach();
131 }
132
133 self.providers.insert(id.clone(), Arc::new(provider));
134 cx.emit(Event::AddedProvider(id));
135 }
136
137 pub fn unregister_provider(
138 &mut self,
139 id: LanguageModelProviderId,
140 cx: &mut ModelContext<Self>,
141 ) {
142 if self.providers.remove(&id).is_some() {
143 cx.emit(Event::RemovedProvider(id));
144 }
145 }
146
147 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
148 let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
149 let mut providers = Vec::with_capacity(self.providers.len());
150 if let Some(provider) = self.providers.get(&zed_provider_id) {
151 providers.push(provider.clone());
152 }
153 providers.extend(self.providers.values().filter_map(|p| {
154 if p.id() != zed_provider_id {
155 Some(p.clone())
156 } else {
157 None
158 }
159 }));
160 providers
161 }
162
163 pub fn available_models<'a>(
164 &'a self,
165 cx: &'a AppContext,
166 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
167 self.providers
168 .values()
169 .flat_map(|provider| provider.provided_models(cx))
170 }
171
172 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
173 self.providers.get(id).cloned()
174 }
175
176 pub fn select_active_model(
177 &mut self,
178 provider: &LanguageModelProviderId,
179 model_id: &LanguageModelId,
180 cx: &mut ModelContext<Self>,
181 ) {
182 let Some(provider) = self.provider(provider) else {
183 return;
184 };
185
186 let models = provider.provided_models(cx);
187 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
188 self.set_active_model(Some(model), cx);
189 }
190 }
191
192 pub fn set_active_provider(
193 &mut self,
194 provider: Option<Arc<dyn LanguageModelProvider>>,
195 cx: &mut ModelContext<Self>,
196 ) {
197 self.active_model = provider.map(|provider| ActiveModel {
198 provider,
199 model: None,
200 });
201 cx.emit(Event::ActiveModelChanged);
202 }
203
204 pub fn set_active_model(
205 &mut self,
206 model: Option<Arc<dyn LanguageModel>>,
207 cx: &mut ModelContext<Self>,
208 ) {
209 if let Some(model) = model {
210 let provider_id = model.provider_id();
211 if let Some(provider) = self.providers.get(&provider_id).cloned() {
212 self.active_model = Some(ActiveModel {
213 provider,
214 model: Some(model),
215 });
216 cx.emit(Event::ActiveModelChanged);
217 } else {
218 log::warn!("Active model's provider not found in registry");
219 }
220 } else {
221 self.active_model = None;
222 cx.emit(Event::ActiveModelChanged);
223 }
224 }
225
226 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
227 Some(self.active_model.as_ref()?.provider.clone())
228 }
229
230 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
231 self.active_model.as_ref()?.model.clone()
232 }
233
234 /// Selects and sets the inline alternatives for language models based on
235 /// provider name and id.
236 pub fn select_inline_alternative_models(
237 &mut self,
238 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
239 cx: &mut ModelContext<Self>,
240 ) {
241 let mut selected_alternatives = Vec::new();
242
243 for (provider_id, model_id) in alternatives {
244 if let Some(provider) = self.providers.get(&provider_id) {
245 if let Some(model) = provider
246 .provided_models(cx)
247 .iter()
248 .find(|m| m.id() == model_id)
249 {
250 selected_alternatives.push(model.clone());
251 }
252 }
253 }
254
255 self.inline_alternatives = selected_alternatives;
256 }
257
258 /// The models to use for inline assists. Returns the union of the active
259 /// model and all inline alternatives. When there are multiple models, the
260 /// user will be able to cycle through results.
261 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
262 &self.inline_alternatives
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider, and
    /// unregistering it by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });

        let after_register = registry.read(cx).providers();
        assert_eq!(after_register.len(), 1);
        assert_eq!(
            after_register[0].id(),
            crate::provider::fake::provider_id()
        );

        registry.update(cx, |this, cx| {
            let fake_id = crate::provider::fake::provider_id();
            this.unregister_provider(fake_id, cx);
        });

        let after_unregister = registry.read(cx).providers();
        assert!(after_unregister.is_empty());
    }
}