1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::sync::Arc;
8use util::maybe;
9
10pub fn init(cx: &mut App) {
11 let registry = cx.new(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
16
17impl Global for GlobalLanguageModelRegistry {}
18
19#[derive(Default)]
20pub struct LanguageModelRegistry {
21 default_model: Option<ConfiguredModel>,
22 default_fast_model: Option<ConfiguredModel>,
23 inline_assistant_model: Option<ConfiguredModel>,
24 commit_message_model: Option<ConfiguredModel>,
25 thread_summary_model: Option<ConfiguredModel>,
26 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
27 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
28}
29
30pub struct SelectedModel {
31 pub provider: LanguageModelProviderId,
32 pub model: LanguageModelId,
33}
34
35#[derive(Clone)]
36pub struct ConfiguredModel {
37 pub provider: Arc<dyn LanguageModelProvider>,
38 pub model: Arc<dyn LanguageModel>,
39}
40
41impl ConfiguredModel {
42 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
43 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
44 }
45}
46
47pub enum Event {
48 DefaultModelChanged,
49 InlineAssistantModelChanged,
50 CommitMessageModelChanged,
51 ThreadSummaryModelChanged,
52 ProviderStateChanged,
53 AddedProvider(LanguageModelProviderId),
54 RemovedProvider(LanguageModelProviderId),
55}
56
57impl EventEmitter<Event> for LanguageModelRegistry {}
58
59impl LanguageModelRegistry {
60 pub fn global(cx: &App) -> Entity<Self> {
61 cx.global::<GlobalLanguageModelRegistry>().0.clone()
62 }
63
64 pub fn read_global(cx: &App) -> &Self {
65 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
66 }
67
68 #[cfg(any(test, feature = "test-support"))]
69 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
70 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
71 let registry = cx.new(|cx| {
72 let mut registry = Self::default();
73 registry.register_provider(fake_provider.clone(), cx);
74 let model = fake_provider.provided_models(cx)[0].clone();
75 let configured_model = ConfiguredModel {
76 provider: Arc::new(fake_provider.clone()),
77 model,
78 };
79 registry.set_default_model(Some(configured_model), cx);
80 registry
81 });
82 cx.set_global(GlobalLanguageModelRegistry(registry));
83 fake_provider
84 }
85
86 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
87 &mut self,
88 provider: T,
89 cx: &mut Context<Self>,
90 ) {
91 let id = provider.id();
92
93 let subscription = provider.subscribe(cx, |_, cx| {
94 cx.emit(Event::ProviderStateChanged);
95 });
96 if let Some(subscription) = subscription {
97 subscription.detach();
98 }
99
100 self.providers.insert(id.clone(), Arc::new(provider));
101 cx.emit(Event::AddedProvider(id));
102 }
103
104 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
105 if self.providers.remove(&id).is_some() {
106 cx.emit(Event::RemovedProvider(id));
107 }
108 }
109
110 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
111 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
112 let mut providers = Vec::with_capacity(self.providers.len());
113 if let Some(provider) = self.providers.get(&zed_provider_id) {
114 providers.push(provider.clone());
115 }
116 providers.extend(self.providers.values().filter_map(|p| {
117 if p.id() != zed_provider_id {
118 Some(p.clone())
119 } else {
120 None
121 }
122 }));
123 providers
124 }
125
126 pub fn available_models<'a>(
127 &'a self,
128 cx: &'a App,
129 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
130 self.providers
131 .values()
132 .flat_map(|provider| provider.provided_models(cx))
133 }
134
135 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
136 self.providers.get(id).cloned()
137 }
138
139 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
140 let configured_model = model.and_then(|model| self.select_model(model, cx));
141 self.set_default_model(configured_model, cx);
142 }
143
144 pub fn select_inline_assistant_model(
145 &mut self,
146 model: Option<&SelectedModel>,
147 cx: &mut Context<Self>,
148 ) {
149 let configured_model = model.and_then(|model| self.select_model(model, cx));
150 self.set_inline_assistant_model(configured_model, cx);
151 }
152
153 pub fn select_commit_message_model(
154 &mut self,
155 model: Option<&SelectedModel>,
156 cx: &mut Context<Self>,
157 ) {
158 let configured_model = model.and_then(|model| self.select_model(model, cx));
159 self.set_commit_message_model(configured_model, cx);
160 }
161
162 pub fn select_thread_summary_model(
163 &mut self,
164 model: Option<&SelectedModel>,
165 cx: &mut Context<Self>,
166 ) {
167 let configured_model = model.and_then(|model| self.select_model(model, cx));
168 self.set_thread_summary_model(configured_model, cx);
169 }
170
171 /// Selects and sets the inline alternatives for language models based on
172 /// provider name and id.
173 pub fn select_inline_alternative_models(
174 &mut self,
175 alternatives: impl IntoIterator<Item = SelectedModel>,
176 cx: &mut Context<Self>,
177 ) {
178 self.inline_alternatives = alternatives
179 .into_iter()
180 .flat_map(|alternative| {
181 self.select_model(&alternative, cx)
182 .map(|configured_model| configured_model.model)
183 })
184 .collect::<Vec<_>>();
185 }
186
187 fn select_model(
188 &mut self,
189 selected_model: &SelectedModel,
190 cx: &mut Context<Self>,
191 ) -> Option<ConfiguredModel> {
192 let provider = self.provider(&selected_model.provider)?;
193 let model = provider
194 .provided_models(cx)
195 .iter()
196 .find(|model| model.id() == selected_model.model)?
197 .clone();
198 Some(ConfiguredModel { provider, model })
199 }
200
201 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
202 match (self.default_model.as_ref(), model.as_ref()) {
203 (Some(old), Some(new)) if old.is_same_as(new) => {}
204 (None, None) => {}
205 _ => cx.emit(Event::DefaultModelChanged),
206 }
207 self.default_fast_model = maybe!({
208 let provider = &model.as_ref()?.provider;
209 let fast_model = provider.default_fast_model(cx)?;
210 Some(ConfiguredModel {
211 provider: provider.clone(),
212 model: fast_model,
213 })
214 });
215 self.default_model = model;
216 }
217
218 pub fn set_inline_assistant_model(
219 &mut self,
220 model: Option<ConfiguredModel>,
221 cx: &mut Context<Self>,
222 ) {
223 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
224 (Some(old), Some(new)) if old.is_same_as(new) => {}
225 (None, None) => {}
226 _ => cx.emit(Event::InlineAssistantModelChanged),
227 }
228 self.inline_assistant_model = model;
229 }
230
231 pub fn set_commit_message_model(
232 &mut self,
233 model: Option<ConfiguredModel>,
234 cx: &mut Context<Self>,
235 ) {
236 match (self.commit_message_model.as_ref(), model.as_ref()) {
237 (Some(old), Some(new)) if old.is_same_as(new) => {}
238 (None, None) => {}
239 _ => cx.emit(Event::CommitMessageModelChanged),
240 }
241 self.commit_message_model = model;
242 }
243
244 pub fn set_thread_summary_model(
245 &mut self,
246 model: Option<ConfiguredModel>,
247 cx: &mut Context<Self>,
248 ) {
249 match (self.thread_summary_model.as_ref(), model.as_ref()) {
250 (Some(old), Some(new)) if old.is_same_as(new) => {}
251 (None, None) => {}
252 _ => cx.emit(Event::ThreadSummaryModelChanged),
253 }
254 self.thread_summary_model = model;
255 }
256
257 pub fn default_model(&self) -> Option<ConfiguredModel> {
258 #[cfg(debug_assertions)]
259 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
260 return None;
261 }
262
263 self.default_model.clone()
264 }
265
266 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
267 #[cfg(debug_assertions)]
268 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
269 return None;
270 }
271
272 self.inline_assistant_model
273 .clone()
274 .or_else(|| self.default_model.clone())
275 }
276
277 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
278 #[cfg(debug_assertions)]
279 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
280 return None;
281 }
282
283 self.commit_message_model
284 .clone()
285 .or_else(|| self.default_model.clone())
286 }
287
288 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
289 #[cfg(debug_assertions)]
290 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
291 return None;
292 }
293
294 self.thread_summary_model
295 .clone()
296 .or_else(|| self.default_fast_model.clone())
297 .or_else(|| self.default_model.clone())
298 }
299
300 /// The models to use for inline assists. Returns the union of the active
301 /// model and all inline alternatives. When there are multiple models, the
302 /// user will be able to cycle through results.
303 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
304 &self.inline_alternatives
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider, and
    /// unregistering it by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_id = crate::fake_provider::provider_id();

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| {
            this.unregister_provider(fake_id, cx);
        });

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}