1use crate::provider::cloud::RefreshLlmTokenListener;
2use crate::{
3 provider::{
4 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
5 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
6 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
7 },
8 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
9 LanguageModelProviderState,
10};
11use client::{Client, UserStore};
12use collections::BTreeMap;
13use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
14use std::sync::Arc;
15use ui::Context;
16
17pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
18 let registry = cx.new_model(|cx| {
19 let mut registry = LanguageModelRegistry::default();
20 register_language_model_providers(&mut registry, user_store, client, cx);
21 registry
22 });
23 cx.set_global(GlobalLanguageModelRegistry(registry));
24}
25
26fn register_language_model_providers(
27 registry: &mut LanguageModelRegistry,
28 user_store: Model<UserStore>,
29 client: Arc<Client>,
30 cx: &mut ModelContext<LanguageModelRegistry>,
31) {
32 use feature_flags::FeatureFlagAppExt;
33
34 RefreshLlmTokenListener::register(client.clone(), cx);
35
36 registry.register_provider(
37 AnthropicLanguageModelProvider::new(client.http_client(), cx),
38 cx,
39 );
40 registry.register_provider(
41 OpenAiLanguageModelProvider::new(client.http_client(), cx),
42 cx,
43 );
44 registry.register_provider(
45 OllamaLanguageModelProvider::new(client.http_client(), cx),
46 cx,
47 );
48 registry.register_provider(
49 GoogleLanguageModelProvider::new(client.http_client(), cx),
50 cx,
51 );
52 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
53
54 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
55 let user_store = user_store.clone();
56 let client = client.clone();
57 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
58 if enabled {
59 registry.register_provider(
60 CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
61 cx,
62 );
63 } else {
64 registry.unregister_provider(
65 LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
66 cx,
67 );
68 }
69 });
70 })
71 .detach();
72}
73
/// Newtype holding the registry model so it can be stored as a gpui global;
/// read back by `LanguageModelRegistry::global` and `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
77
/// Registry of language model providers, together with the currently active
/// provider/model and the models chosen as inline-assist alternatives.
#[derive(Default)]
pub struct LanguageModelRegistry {
    // The current selection: a provider and, optionally, one of its models.
    active_model: Option<ActiveModel>,
    // Registered providers keyed by id; a BTreeMap, so iteration is in id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    // Models chosen via `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
84
/// The registry's active selection.
///
/// `model` is `None` when only a provider was chosen (see
/// `LanguageModelRegistry::set_active_provider`).
pub struct ActiveModel {
    provider: Arc<dyn LanguageModelProvider>,
    model: Option<Arc<dyn LanguageModel>>,
}
89
/// Events emitted by `LanguageModelRegistry`.
pub enum Event {
    /// The active provider and/or model was set or cleared.
    ActiveModelChanged,
    /// A registered provider notified a state change (forwarded from the
    /// subscription created in `register_provider`).
    ProviderStateChanged,
    /// A provider with this id was registered.
    AddedProvider(LanguageModelProviderId),
    /// A provider with this id was unregistered.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
98
99impl LanguageModelRegistry {
100 pub fn global(cx: &AppContext) -> Model<Self> {
101 cx.global::<GlobalLanguageModelRegistry>().0.clone()
102 }
103
104 pub fn read_global(cx: &AppContext) -> &Self {
105 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
106 }
107
108 #[cfg(any(test, feature = "test-support"))]
109 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
110 let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
111 let registry = cx.new_model(|cx| {
112 let mut registry = Self::default();
113 registry.register_provider(fake_provider.clone(), cx);
114 let model = fake_provider.provided_models(cx)[0].clone();
115 registry.set_active_model(Some(model), cx);
116 registry
117 });
118 cx.set_global(GlobalLanguageModelRegistry(registry));
119 fake_provider
120 }
121
122 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
123 &mut self,
124 provider: T,
125 cx: &mut ModelContext<Self>,
126 ) {
127 let id = provider.id();
128
129 let subscription = provider.subscribe(cx, |_, cx| {
130 cx.emit(Event::ProviderStateChanged);
131 });
132 if let Some(subscription) = subscription {
133 subscription.detach();
134 }
135
136 self.providers.insert(id.clone(), Arc::new(provider));
137 cx.emit(Event::AddedProvider(id));
138 }
139
140 pub fn unregister_provider(
141 &mut self,
142 id: LanguageModelProviderId,
143 cx: &mut ModelContext<Self>,
144 ) {
145 if self.providers.remove(&id).is_some() {
146 cx.emit(Event::RemovedProvider(id));
147 }
148 }
149
150 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
151 let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
152 let mut providers = Vec::with_capacity(self.providers.len());
153 if let Some(provider) = self.providers.get(&zed_provider_id) {
154 providers.push(provider.clone());
155 }
156 providers.extend(self.providers.values().filter_map(|p| {
157 if p.id() != zed_provider_id {
158 Some(p.clone())
159 } else {
160 None
161 }
162 }));
163 providers
164 }
165
166 pub fn available_models<'a>(
167 &'a self,
168 cx: &'a AppContext,
169 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
170 self.providers
171 .values()
172 .flat_map(|provider| provider.provided_models(cx))
173 }
174
175 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
176 self.providers.get(id).cloned()
177 }
178
179 pub fn select_active_model(
180 &mut self,
181 provider: &LanguageModelProviderId,
182 model_id: &LanguageModelId,
183 cx: &mut ModelContext<Self>,
184 ) {
185 let Some(provider) = self.provider(provider) else {
186 return;
187 };
188
189 let models = provider.provided_models(cx);
190 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
191 self.set_active_model(Some(model), cx);
192 }
193 }
194
195 pub fn set_active_provider(
196 &mut self,
197 provider: Option<Arc<dyn LanguageModelProvider>>,
198 cx: &mut ModelContext<Self>,
199 ) {
200 self.active_model = provider.map(|provider| ActiveModel {
201 provider,
202 model: None,
203 });
204 cx.emit(Event::ActiveModelChanged);
205 }
206
207 pub fn set_active_model(
208 &mut self,
209 model: Option<Arc<dyn LanguageModel>>,
210 cx: &mut ModelContext<Self>,
211 ) {
212 if let Some(model) = model {
213 let provider_id = model.provider_id();
214 if let Some(provider) = self.providers.get(&provider_id).cloned() {
215 self.active_model = Some(ActiveModel {
216 provider,
217 model: Some(model),
218 });
219 cx.emit(Event::ActiveModelChanged);
220 } else {
221 log::warn!("Active model's provider not found in registry");
222 }
223 } else {
224 self.active_model = None;
225 cx.emit(Event::ActiveModelChanged);
226 }
227 }
228
229 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
230 Some(self.active_model.as_ref()?.provider.clone())
231 }
232
233 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
234 self.active_model.as_ref()?.model.clone()
235 }
236
237 /// Selects and sets the inline alternatives for language models based on
238 /// provider name and id.
239 pub fn select_inline_alternative_models(
240 &mut self,
241 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
242 cx: &mut ModelContext<Self>,
243 ) {
244 let mut selected_alternatives = Vec::new();
245
246 for (provider_id, model_id) in alternatives {
247 if let Some(provider) = self.providers.get(&provider_id) {
248 if let Some(model) = provider
249 .provided_models(cx)
250 .iter()
251 .find(|m| m.id() == model_id)
252 {
253 selected_alternatives.push(model.clone());
254 }
255 }
256 }
257
258 self.inline_alternatives = selected_alternatives;
259 }
260
261 /// The models to use for inline assists. Returns the union of the active
262 /// model and all inline alternatives. When there are multiple models, the
263 /// user will be able to cycle through results.
264 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
265 &self.inline_alternatives
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id as fake_provider_id, FakeLanguageModelProvider};

    /// Registering then unregistering the fake provider should leave the registry
    /// exactly as it started: empty.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| this.register_provider(FakeLanguageModelProvider, cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_provider_id());

        registry.update(cx, |this, cx| this.unregister_provider(fake_provider_id(), cx));

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}