language_models.rs

 1use std::sync::Arc;
 2
 3use client::{Client, UserStore};
 4use fs::Fs;
 5use gpui::{App, Context, Entity};
 6use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
 7use provider::deepseek::DeepSeekLanguageModelProvider;
 8
 9pub mod provider;
10mod settings;
11
12use crate::provider::anthropic::AnthropicLanguageModelProvider;
13use crate::provider::cloud::CloudLanguageModelProvider;
14use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
15use crate::provider::google::GoogleLanguageModelProvider;
16use crate::provider::lmstudio::LmStudioLanguageModelProvider;
17use crate::provider::mistral::MistralLanguageModelProvider;
18use crate::provider::ollama::OllamaLanguageModelProvider;
19use crate::provider::open_ai::OpenAiLanguageModelProvider;
20pub use crate::settings::*;
21
22pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
23    crate::settings::init(fs, cx);
24    let registry = LanguageModelRegistry::global(cx);
25    registry.update(cx, |registry, cx| {
26        register_language_model_providers(registry, user_store, client, cx);
27    });
28}
29
30fn register_language_model_providers(
31    registry: &mut LanguageModelRegistry,
32    user_store: Entity<UserStore>,
33    client: Arc<Client>,
34    cx: &mut Context<LanguageModelRegistry>,
35) {
36    use feature_flags::FeatureFlagAppExt;
37
38    registry.register_provider(
39        AnthropicLanguageModelProvider::new(client.http_client(), cx),
40        cx,
41    );
42    registry.register_provider(
43        OpenAiLanguageModelProvider::new(client.http_client(), cx),
44        cx,
45    );
46    registry.register_provider(
47        OllamaLanguageModelProvider::new(client.http_client(), cx),
48        cx,
49    );
50    registry.register_provider(
51        LmStudioLanguageModelProvider::new(client.http_client(), cx),
52        cx,
53    );
54    registry.register_provider(
55        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
56        cx,
57    );
58    registry.register_provider(
59        GoogleLanguageModelProvider::new(client.http_client(), cx),
60        cx,
61    );
62    registry.register_provider(
63        MistralLanguageModelProvider::new(client.http_client(), cx),
64        cx,
65    );
66    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
67
68    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
69        let user_store = user_store.clone();
70        let client = client.clone();
71        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
72            if enabled {
73                registry.register_provider(
74                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
75                    cx,
76                );
77            } else {
78                registry.unregister_provider(
79                    LanguageModelProviderId::from(ZED_CLOUD_PROVIDER_ID.to_string()),
80                    cx,
81                );
82            }
83        });
84    })
85    .detach();
86}