1mod anthropic;
2mod cloud;
3#[cfg(test)]
4mod fake;
5mod open_ai;
6
7pub use anthropic::*;
8pub use cloud::*;
9#[cfg(test)]
10pub use fake::*;
11pub use open_ai::*;
12
13use crate::{
14 assistant_settings::{AssistantProvider, AssistantSettings},
15 LanguageModel, LanguageModelRequest,
16};
17use anyhow::Result;
18use client::Client;
19use futures::{future::BoxFuture, stream::BoxStream};
20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
21use settings::{Settings, SettingsStore};
22use std::sync::Arc;
23use std::time::Duration;
24
25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
26 let mut settings_version = 0;
27 let provider = match &AssistantSettings::get_global(cx).provider {
28 AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
29 CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
30 ),
31 AssistantProvider::OpenAi {
32 model,
33 api_url,
34 low_speed_timeout_in_seconds,
35 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
36 model.clone(),
37 api_url.clone(),
38 client.http_client(),
39 low_speed_timeout_in_seconds.map(Duration::from_secs),
40 settings_version,
41 )),
42 AssistantProvider::Anthropic {
43 model,
44 api_url,
45 low_speed_timeout_in_seconds,
46 } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
47 model.clone(),
48 api_url.clone(),
49 client.http_client(),
50 low_speed_timeout_in_seconds.map(Duration::from_secs),
51 settings_version,
52 )),
53 };
54 cx.set_global(provider);
55
56 cx.observe_global::<SettingsStore>(move |cx| {
57 settings_version += 1;
58 cx.update_global::<CompletionProvider, _>(|provider, cx| {
59 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
60 (
61 CompletionProvider::OpenAi(provider),
62 AssistantProvider::OpenAi {
63 model,
64 api_url,
65 low_speed_timeout_in_seconds,
66 },
67 ) => {
68 provider.update(
69 model.clone(),
70 api_url.clone(),
71 low_speed_timeout_in_seconds.map(Duration::from_secs),
72 settings_version,
73 );
74 }
75 (
76 CompletionProvider::Anthropic(provider),
77 AssistantProvider::Anthropic {
78 model,
79 api_url,
80 low_speed_timeout_in_seconds,
81 },
82 ) => {
83 provider.update(
84 model.clone(),
85 api_url.clone(),
86 low_speed_timeout_in_seconds.map(Duration::from_secs),
87 settings_version,
88 );
89 }
90 (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
91 provider.update(model.clone(), settings_version);
92 }
93 (_, AssistantProvider::ZedDotDev { model }) => {
94 *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
95 model.clone(),
96 client.clone(),
97 settings_version,
98 cx,
99 ));
100 }
101 (
102 _,
103 AssistantProvider::OpenAi {
104 model,
105 api_url,
106 low_speed_timeout_in_seconds,
107 },
108 ) => {
109 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
110 model.clone(),
111 api_url.clone(),
112 client.http_client(),
113 low_speed_timeout_in_seconds.map(Duration::from_secs),
114 settings_version,
115 ));
116 }
117 (
118 _,
119 AssistantProvider::Anthropic {
120 model,
121 api_url,
122 low_speed_timeout_in_seconds,
123 },
124 ) => {
125 *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
126 model.clone(),
127 api_url.clone(),
128 client.http_client(),
129 low_speed_timeout_in_seconds.map(Duration::from_secs),
130 settings_version,
131 ));
132 }
133 }
134 })
135 })
136 .detach();
137}
138
/// The completion backend currently in use, with one variant per supported provider.
///
/// Each variant wraps the provider-specific implementation; the inherent methods on
/// this enum forward to whichever variant is active.
pub enum CompletionProvider {
    /// Backend implemented by [`OpenAiCompletionProvider`].
    OpenAi(OpenAiCompletionProvider),
    /// Backend implemented by [`AnthropicCompletionProvider`].
    Anthropic(AnthropicCompletionProvider),
    /// Backend implemented by [`CloudCompletionProvider`].
    Cloud(CloudCompletionProvider),
    /// Test double; only compiled in test builds.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}
146
// Empty `gpui::Global` impl: lets a `CompletionProvider` be stored with
// `cx.set_global(..)` and read back with `cx.global::<CompletionProvider>()`.
impl gpui::Global for CompletionProvider {}
148
149impl CompletionProvider {
150 pub fn global(cx: &AppContext) -> &Self {
151 cx.global::<Self>()
152 }
153
154 pub fn available_models(&self) -> Vec<LanguageModel> {
155 match self {
156 CompletionProvider::OpenAi(provider) => provider
157 .available_models()
158 .map(LanguageModel::OpenAi)
159 .collect(),
160 CompletionProvider::Anthropic(provider) => provider
161 .available_models()
162 .map(LanguageModel::Anthropic)
163 .collect(),
164 CompletionProvider::Cloud(provider) => provider
165 .available_models()
166 .map(LanguageModel::Cloud)
167 .collect(),
168 #[cfg(test)]
169 CompletionProvider::Fake(_) => unimplemented!(),
170 }
171 }
172
173 pub fn settings_version(&self) -> usize {
174 match self {
175 CompletionProvider::OpenAi(provider) => provider.settings_version(),
176 CompletionProvider::Anthropic(provider) => provider.settings_version(),
177 CompletionProvider::Cloud(provider) => provider.settings_version(),
178 #[cfg(test)]
179 CompletionProvider::Fake(_) => unimplemented!(),
180 }
181 }
182
183 pub fn is_authenticated(&self) -> bool {
184 match self {
185 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
186 CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
187 CompletionProvider::Cloud(provider) => provider.is_authenticated(),
188 #[cfg(test)]
189 CompletionProvider::Fake(_) => true,
190 }
191 }
192
193 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
194 match self {
195 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
196 CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
197 CompletionProvider::Cloud(provider) => provider.authenticate(cx),
198 #[cfg(test)]
199 CompletionProvider::Fake(_) => Task::ready(Ok(())),
200 }
201 }
202
203 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
204 match self {
205 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
206 CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
207 CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
208 #[cfg(test)]
209 CompletionProvider::Fake(_) => unimplemented!(),
210 }
211 }
212
213 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
214 match self {
215 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
216 CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
217 CompletionProvider::Cloud(_) => Task::ready(Ok(())),
218 #[cfg(test)]
219 CompletionProvider::Fake(_) => Task::ready(Ok(())),
220 }
221 }
222
223 pub fn model(&self) -> LanguageModel {
224 match self {
225 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
226 CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
227 CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
228 #[cfg(test)]
229 CompletionProvider::Fake(_) => LanguageModel::default(),
230 }
231 }
232
233 pub fn count_tokens(
234 &self,
235 request: LanguageModelRequest,
236 cx: &AppContext,
237 ) -> BoxFuture<'static, Result<usize>> {
238 match self {
239 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
240 CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
241 CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
242 #[cfg(test)]
243 CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
244 }
245 }
246
247 pub fn complete(
248 &self,
249 request: LanguageModelRequest,
250 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
251 match self {
252 CompletionProvider::OpenAi(provider) => provider.complete(request),
253 CompletionProvider::Anthropic(provider) => provider.complete(request),
254 CompletionProvider::Cloud(provider) => provider.complete(request),
255 #[cfg(test)]
256 CompletionProvider::Fake(provider) => provider.complete(),
257 }
258 }
259}