completion_provider.rs

  1#[cfg(test)]
  2mod fake;
  3mod open_ai;
  4mod zed;
  5
  6#[cfg(test)]
  7pub use fake::*;
  8pub use open_ai::*;
  9pub use zed::*;
 10
 11use crate::{
 12    assistant_settings::{AssistantProvider, AssistantSettings},
 13    LanguageModel, LanguageModelRequest,
 14};
 15use anyhow::Result;
 16use client::Client;
 17use futures::{future::BoxFuture, stream::BoxStream};
 18use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 19use settings::{Settings, SettingsStore};
 20use std::sync::Arc;
 21use std::time::Duration;
 22
 23pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 24    let mut settings_version = 0;
 25    let provider = match &AssistantSettings::get_global(cx).provider {
 26        AssistantProvider::ZedDotDev { default_model } => {
 27            CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 28                default_model.clone(),
 29                client.clone(),
 30                settings_version,
 31                cx,
 32            ))
 33        }
 34        AssistantProvider::OpenAi {
 35            default_model,
 36            api_url,
 37            low_speed_timeout_in_seconds,
 38        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 39            default_model.clone(),
 40            api_url.clone(),
 41            client.http_client(),
 42            low_speed_timeout_in_seconds.map(Duration::from_secs),
 43            settings_version,
 44        )),
 45    };
 46    cx.set_global(provider);
 47
 48    cx.observe_global::<SettingsStore>(move |cx| {
 49        settings_version += 1;
 50        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 51            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 52                (
 53                    CompletionProvider::OpenAi(provider),
 54                    AssistantProvider::OpenAi {
 55                        default_model,
 56                        api_url,
 57                        low_speed_timeout_in_seconds,
 58                    },
 59                ) => {
 60                    provider.update(
 61                        default_model.clone(),
 62                        api_url.clone(),
 63                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 64                        settings_version,
 65                    );
 66                }
 67                (
 68                    CompletionProvider::ZedDotDev(provider),
 69                    AssistantProvider::ZedDotDev { default_model },
 70                ) => {
 71                    provider.update(default_model.clone(), settings_version);
 72                }
 73                (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
 74                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 75                        default_model.clone(),
 76                        client.clone(),
 77                        settings_version,
 78                        cx,
 79                    ));
 80                }
 81                (
 82                    CompletionProvider::ZedDotDev(_),
 83                    AssistantProvider::OpenAi {
 84                        default_model,
 85                        api_url,
 86                        low_speed_timeout_in_seconds,
 87                    },
 88                ) => {
 89                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 90                        default_model.clone(),
 91                        api_url.clone(),
 92                        client.http_client(),
 93                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 94                        settings_version,
 95                    ));
 96                }
 97                #[cfg(test)]
 98                (CompletionProvider::Fake(_), _) => unimplemented!(),
 99            }
100        })
101    })
102    .detach();
103}
104
105pub enum CompletionProvider {
106    OpenAi(OpenAiCompletionProvider),
107    ZedDotDev(ZedDotDevCompletionProvider),
108    #[cfg(test)]
109    Fake(FakeCompletionProvider),
110}
111
112impl gpui::Global for CompletionProvider {}
113
114impl CompletionProvider {
115    pub fn global(cx: &AppContext) -> &Self {
116        cx.global::<Self>()
117    }
118
119    pub fn settings_version(&self) -> usize {
120        match self {
121            CompletionProvider::OpenAi(provider) => provider.settings_version(),
122            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
123            #[cfg(test)]
124            CompletionProvider::Fake(_) => unimplemented!(),
125        }
126    }
127
128    pub fn is_authenticated(&self) -> bool {
129        match self {
130            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
131            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
132            #[cfg(test)]
133            CompletionProvider::Fake(_) => true,
134        }
135    }
136
137    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
138        match self {
139            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
140            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
141            #[cfg(test)]
142            CompletionProvider::Fake(_) => Task::ready(Ok(())),
143        }
144    }
145
146    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
147        match self {
148            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
149            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
150            #[cfg(test)]
151            CompletionProvider::Fake(_) => unimplemented!(),
152        }
153    }
154
155    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
156        match self {
157            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
158            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
159            #[cfg(test)]
160            CompletionProvider::Fake(_) => Task::ready(Ok(())),
161        }
162    }
163
164    pub fn default_model(&self) -> LanguageModel {
165        match self {
166            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
167            CompletionProvider::ZedDotDev(provider) => {
168                LanguageModel::ZedDotDev(provider.default_model())
169            }
170            #[cfg(test)]
171            CompletionProvider::Fake(_) => unimplemented!(),
172        }
173    }
174
175    pub fn count_tokens(
176        &self,
177        request: LanguageModelRequest,
178        cx: &AppContext,
179    ) -> BoxFuture<'static, Result<usize>> {
180        match self {
181            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
182            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
183            #[cfg(test)]
184            CompletionProvider::Fake(_) => unimplemented!(),
185        }
186    }
187
188    pub fn complete(
189        &self,
190        request: LanguageModelRequest,
191    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192        match self {
193            CompletionProvider::OpenAi(provider) => provider.complete(request),
194            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
195            #[cfg(test)]
196            CompletionProvider::Fake(provider) => provider.complete(),
197        }
198    }
199}