language_models.rs

 1use std::sync::Arc;
 2
 3use client::{Client, UserStore};
 4use fs::Fs;
 5use gpui::{App, Context, Entity};
 6use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
 7use provider::deepseek::DeepSeekLanguageModelProvider;
 8
 9pub mod provider;
10mod settings;
11
12use crate::provider::anthropic::AnthropicLanguageModelProvider;
13use crate::provider::cloud::CloudLanguageModelProvider;
14pub use crate::provider::cloud::LlmApiToken;
15pub use crate::provider::cloud::RefreshLlmTokenListener;
16use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
17use crate::provider::google::GoogleLanguageModelProvider;
18use crate::provider::lmstudio::LmStudioLanguageModelProvider;
19use crate::provider::mistral::MistralLanguageModelProvider;
20use crate::provider::ollama::OllamaLanguageModelProvider;
21use crate::provider::open_ai::OpenAiLanguageModelProvider;
22pub use crate::settings::*;
23
24pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
25    crate::settings::init(fs, cx);
26    let registry = LanguageModelRegistry::global(cx);
27    registry.update(cx, |registry, cx| {
28        register_language_model_providers(registry, user_store, client, cx);
29    });
30}
31
32fn register_language_model_providers(
33    registry: &mut LanguageModelRegistry,
34    user_store: Entity<UserStore>,
35    client: Arc<Client>,
36    cx: &mut Context<LanguageModelRegistry>,
37) {
38    use feature_flags::FeatureFlagAppExt;
39
40    RefreshLlmTokenListener::register(client.clone(), cx);
41
42    registry.register_provider(
43        AnthropicLanguageModelProvider::new(client.http_client(), cx),
44        cx,
45    );
46    registry.register_provider(
47        OpenAiLanguageModelProvider::new(client.http_client(), cx),
48        cx,
49    );
50    registry.register_provider(
51        OllamaLanguageModelProvider::new(client.http_client(), cx),
52        cx,
53    );
54    registry.register_provider(
55        LmStudioLanguageModelProvider::new(client.http_client(), cx),
56        cx,
57    );
58    registry.register_provider(
59        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
60        cx,
61    );
62    registry.register_provider(
63        GoogleLanguageModelProvider::new(client.http_client(), cx),
64        cx,
65    );
66    registry.register_provider(
67        MistralLanguageModelProvider::new(client.http_client(), cx),
68        cx,
69    );
70    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
71
72    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
73        let user_store = user_store.clone();
74        let client = client.clone();
75        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
76            if enabled {
77                registry.register_provider(
78                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
79                    cx,
80                );
81            } else {
82                registry.unregister_provider(
83                    LanguageModelProviderId::from(ZED_CLOUD_PROVIDER_ID.to_string()),
84                    cx,
85                );
86            }
87        });
88    })
89    .detach();
90}