language_models.rs

 1use std::sync::Arc;
 2
 3use client::{Client, UserStore};
 4use fs::Fs;
 5use gpui::{AppContext, Model, ModelContext};
 6use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
 7use provider::deepseek::DeepSeekLanguageModelProvider;
 8
 9mod logging;
10pub mod provider;
11mod settings;
12
13use crate::provider::anthropic::AnthropicLanguageModelProvider;
14use crate::provider::cloud::CloudLanguageModelProvider;
15pub use crate::provider::cloud::LlmApiToken;
16pub use crate::provider::cloud::RefreshLlmTokenListener;
17use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
18use crate::provider::google::GoogleLanguageModelProvider;
19use crate::provider::lmstudio::LmStudioLanguageModelProvider;
20use crate::provider::ollama::OllamaLanguageModelProvider;
21use crate::provider::open_ai::OpenAiLanguageModelProvider;
22pub use crate::settings::*;
23pub use logging::report_assistant_event;
24
25pub fn init(
26    user_store: Model<UserStore>,
27    client: Arc<Client>,
28    fs: Arc<dyn Fs>,
29    cx: &mut AppContext,
30) {
31    crate::settings::init(fs, cx);
32    let registry = LanguageModelRegistry::global(cx);
33    registry.update(cx, |registry, cx| {
34        register_language_model_providers(registry, user_store, client, cx);
35    });
36}
37
38fn register_language_model_providers(
39    registry: &mut LanguageModelRegistry,
40    user_store: Model<UserStore>,
41    client: Arc<Client>,
42    cx: &mut ModelContext<LanguageModelRegistry>,
43) {
44    use feature_flags::FeatureFlagAppExt;
45
46    RefreshLlmTokenListener::register(client.clone(), cx);
47
48    registry.register_provider(
49        AnthropicLanguageModelProvider::new(client.http_client(), cx),
50        cx,
51    );
52    registry.register_provider(
53        OpenAiLanguageModelProvider::new(client.http_client(), cx),
54        cx,
55    );
56    registry.register_provider(
57        OllamaLanguageModelProvider::new(client.http_client(), cx),
58        cx,
59    );
60    registry.register_provider(
61        LmStudioLanguageModelProvider::new(client.http_client(), cx),
62        cx,
63    );
64    registry.register_provider(
65        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
66        cx,
67    );
68    registry.register_provider(
69        GoogleLanguageModelProvider::new(client.http_client(), cx),
70        cx,
71    );
72    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
73
74    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
75        let user_store = user_store.clone();
76        let client = client.clone();
77        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
78            if enabled {
79                registry.register_provider(
80                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
81                    cx,
82                );
83            } else {
84                registry.unregister_provider(
85                    LanguageModelProviderId::from(ZED_CLOUD_PROVIDER_ID.to_string()),
86                    cx,
87                );
88            }
89        });
90    })
91    .detach();
92}