language_models.rs

 1use std::sync::Arc;
 2
 3use client::{Client, UserStore};
 4use fs::Fs;
 5use gpui::{AppContext, Model, ModelContext};
 6use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
 7
 8mod logging;
 9pub mod provider;
10mod settings;
11
12use crate::provider::anthropic::AnthropicLanguageModelProvider;
13use crate::provider::cloud::{CloudLanguageModelProvider, RefreshLlmTokenListener};
14use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
15use crate::provider::google::GoogleLanguageModelProvider;
16use crate::provider::ollama::OllamaLanguageModelProvider;
17use crate::provider::open_ai::OpenAiLanguageModelProvider;
18pub use crate::settings::*;
19pub use logging::report_assistant_event;
20
21pub fn init(
22    user_store: Model<UserStore>,
23    client: Arc<Client>,
24    fs: Arc<dyn Fs>,
25    cx: &mut AppContext,
26) {
27    crate::settings::init(fs, cx);
28    let registry = LanguageModelRegistry::global(cx);
29    registry.update(cx, |registry, cx| {
30        register_language_model_providers(registry, user_store, client, cx);
31    });
32}
33
34fn register_language_model_providers(
35    registry: &mut LanguageModelRegistry,
36    user_store: Model<UserStore>,
37    client: Arc<Client>,
38    cx: &mut ModelContext<LanguageModelRegistry>,
39) {
40    use feature_flags::FeatureFlagAppExt;
41
42    RefreshLlmTokenListener::register(client.clone(), cx);
43
44    registry.register_provider(
45        AnthropicLanguageModelProvider::new(client.http_client(), cx),
46        cx,
47    );
48    registry.register_provider(
49        OpenAiLanguageModelProvider::new(client.http_client(), cx),
50        cx,
51    );
52    registry.register_provider(
53        OllamaLanguageModelProvider::new(client.http_client(), cx),
54        cx,
55    );
56    registry.register_provider(
57        GoogleLanguageModelProvider::new(client.http_client(), cx),
58        cx,
59    );
60    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
61
62    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
63        let user_store = user_store.clone();
64        let client = client.clone();
65        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
66            if enabled {
67                registry.register_provider(
68                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
69                    cx,
70                );
71            } else {
72                registry.unregister_provider(
73                    LanguageModelProviderId::from(ZED_CLOUD_PROVIDER_ID.to_string()),
74                    cx,
75                );
76            }
77        });
78    })
79    .detach();
80}