1mod anthropic;
2#[cfg(test)]
3mod fake;
4mod open_ai;
5mod zed;
6
7pub use anthropic::*;
8#[cfg(test)]
9pub use fake::*;
10pub use open_ai::*;
11pub use zed::*;
12
13use crate::{
14 assistant_settings::{AssistantProvider, AssistantSettings},
15 LanguageModel, LanguageModelRequest,
16};
17use anyhow::Result;
18use client::Client;
19use futures::{future::BoxFuture, stream::BoxStream};
20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
21use settings::{Settings, SettingsStore};
22use std::sync::Arc;
23use std::time::Duration;
24
25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
26 let mut settings_version = 0;
27 let provider = match &AssistantSettings::get_global(cx).provider {
28 AssistantProvider::ZedDotDev { default_model } => {
29 CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
30 default_model.clone(),
31 client.clone(),
32 settings_version,
33 cx,
34 ))
35 }
36 AssistantProvider::OpenAi {
37 default_model,
38 api_url,
39 low_speed_timeout_in_seconds,
40 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
41 default_model.clone(),
42 api_url.clone(),
43 client.http_client(),
44 low_speed_timeout_in_seconds.map(Duration::from_secs),
45 settings_version,
46 )),
47 AssistantProvider::Anthropic {
48 default_model,
49 api_url,
50 low_speed_timeout_in_seconds,
51 } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
52 default_model.clone(),
53 api_url.clone(),
54 client.http_client(),
55 low_speed_timeout_in_seconds.map(Duration::from_secs),
56 settings_version,
57 )),
58 };
59 cx.set_global(provider);
60
61 cx.observe_global::<SettingsStore>(move |cx| {
62 settings_version += 1;
63 cx.update_global::<CompletionProvider, _>(|provider, cx| {
64 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
65 (
66 CompletionProvider::OpenAi(provider),
67 AssistantProvider::OpenAi {
68 default_model,
69 api_url,
70 low_speed_timeout_in_seconds,
71 },
72 ) => {
73 provider.update(
74 default_model.clone(),
75 api_url.clone(),
76 low_speed_timeout_in_seconds.map(Duration::from_secs),
77 settings_version,
78 );
79 }
80 (
81 CompletionProvider::Anthropic(provider),
82 AssistantProvider::Anthropic {
83 default_model,
84 api_url,
85 low_speed_timeout_in_seconds,
86 },
87 ) => {
88 provider.update(
89 default_model.clone(),
90 api_url.clone(),
91 low_speed_timeout_in_seconds.map(Duration::from_secs),
92 settings_version,
93 );
94 }
95 (
96 CompletionProvider::ZedDotDev(provider),
97 AssistantProvider::ZedDotDev { default_model },
98 ) => {
99 provider.update(default_model.clone(), settings_version);
100 }
101 (_, AssistantProvider::ZedDotDev { default_model }) => {
102 *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
103 default_model.clone(),
104 client.clone(),
105 settings_version,
106 cx,
107 ));
108 }
109 (
110 _,
111 AssistantProvider::OpenAi {
112 default_model,
113 api_url,
114 low_speed_timeout_in_seconds,
115 },
116 ) => {
117 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
118 default_model.clone(),
119 api_url.clone(),
120 client.http_client(),
121 low_speed_timeout_in_seconds.map(Duration::from_secs),
122 settings_version,
123 ));
124 }
125 (
126 _,
127 AssistantProvider::Anthropic {
128 default_model,
129 api_url,
130 low_speed_timeout_in_seconds,
131 },
132 ) => {
133 *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
134 default_model.clone(),
135 api_url.clone(),
136 client.http_client(),
137 low_speed_timeout_in_seconds.map(Duration::from_secs),
138 settings_version,
139 ));
140 }
141 }
142 })
143 })
144 .detach();
145}
146
147pub enum CompletionProvider {
148 OpenAi(OpenAiCompletionProvider),
149 Anthropic(AnthropicCompletionProvider),
150 ZedDotDev(ZedDotDevCompletionProvider),
151 #[cfg(test)]
152 Fake(FakeCompletionProvider),
153}
154
155impl gpui::Global for CompletionProvider {}
156
157impl CompletionProvider {
158 pub fn global(cx: &AppContext) -> &Self {
159 cx.global::<Self>()
160 }
161
162 pub fn settings_version(&self) -> usize {
163 match self {
164 CompletionProvider::OpenAi(provider) => provider.settings_version(),
165 CompletionProvider::Anthropic(provider) => provider.settings_version(),
166 CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
167 #[cfg(test)]
168 CompletionProvider::Fake(_) => unimplemented!(),
169 }
170 }
171
172 pub fn is_authenticated(&self) -> bool {
173 match self {
174 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
175 CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
176 CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
177 #[cfg(test)]
178 CompletionProvider::Fake(_) => true,
179 }
180 }
181
182 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
183 match self {
184 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
185 CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
186 CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
187 #[cfg(test)]
188 CompletionProvider::Fake(_) => Task::ready(Ok(())),
189 }
190 }
191
192 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
193 match self {
194 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
195 CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
196 CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
197 #[cfg(test)]
198 CompletionProvider::Fake(_) => unimplemented!(),
199 }
200 }
201
202 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
203 match self {
204 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
205 CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
206 CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
207 #[cfg(test)]
208 CompletionProvider::Fake(_) => Task::ready(Ok(())),
209 }
210 }
211
212 pub fn default_model(&self) -> LanguageModel {
213 match self {
214 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
215 CompletionProvider::Anthropic(provider) => {
216 LanguageModel::Anthropic(provider.default_model())
217 }
218 CompletionProvider::ZedDotDev(provider) => {
219 LanguageModel::ZedDotDev(provider.default_model())
220 }
221 #[cfg(test)]
222 CompletionProvider::Fake(_) => unimplemented!(),
223 }
224 }
225
226 pub fn count_tokens(
227 &self,
228 request: LanguageModelRequest,
229 cx: &AppContext,
230 ) -> BoxFuture<'static, Result<usize>> {
231 match self {
232 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
233 CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
234 CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
235 #[cfg(test)]
236 CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
237 }
238 }
239
240 pub fn complete(
241 &self,
242 request: LanguageModelRequest,
243 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
244 match self {
245 CompletionProvider::OpenAi(provider) => provider.complete(request),
246 CompletionProvider::Anthropic(provider) => provider.complete(request),
247 CompletionProvider::ZedDotDev(provider) => provider.complete(request),
248 #[cfg(test)]
249 CompletionProvider::Fake(provider) => provider.complete(),
250 }
251 }
252}