1use crate::{
2 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
3 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
4 LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
5};
6use anyhow::{anyhow, Context as _, Result};
7use collections::BTreeMap;
8use editor::{Editor, EditorElement, EditorStyle};
9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
10use gpui::{
11 AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext,
12 Subscription, Task, TextStyle, View, WhiteSpace,
13};
14use http_client::HttpClient;
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{sync::Arc, time::Duration};
19use strum::IntoEnumIterator;
20use theme::ThemeSettings;
21use ui::{prelude::*, Indicator};
22
/// Identifier returned by the provider's `id()` and each model's `provider_id()`.
const PROVIDER_ID: &str = "anthropic";
/// Name returned by the provider's `name()` and each model's `provider_name()`.
const PROVIDER_NAME: &str = "Anthropic";
25
26#[derive(Default, Clone, Debug, PartialEq)]
27pub struct AnthropicSettings {
28 pub api_url: String,
29 pub low_speed_timeout: Option<Duration>,
30 pub available_models: Vec<AvailableModel>,
31 pub needs_setting_migration: bool,
32}
33
34#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
35pub struct AvailableModel {
36 pub name: String,
37 pub max_tokens: usize,
38 pub tool_override: Option<String>,
39}
40
41pub struct AnthropicLanguageModelProvider {
42 http_client: Arc<dyn HttpClient>,
43 state: gpui::Model<State>,
44}
45
46pub struct State {
47 api_key: Option<String>,
48 _subscription: Subscription,
49}
50
51impl State {
52 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
53 let delete_credentials =
54 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
55 cx.spawn(|this, mut cx| async move {
56 delete_credentials.await.ok();
57 this.update(&mut cx, |this, cx| {
58 this.api_key = None;
59 cx.notify();
60 })
61 })
62 }
63
64 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
65 let write_credentials = cx.write_credentials(
66 AllLanguageModelSettings::get_global(cx)
67 .anthropic
68 .api_url
69 .as_str(),
70 "Bearer",
71 api_key.as_bytes(),
72 );
73 cx.spawn(|this, mut cx| async move {
74 write_credentials.await?;
75
76 this.update(&mut cx, |this, cx| {
77 this.api_key = Some(api_key);
78 cx.notify();
79 })
80 })
81 }
82
83 fn is_authenticated(&self) -> bool {
84 self.api_key.is_some()
85 }
86}
87
88impl AnthropicLanguageModelProvider {
89 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
90 let state = cx.new_model(|cx| State {
91 api_key: None,
92 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
93 cx.notify();
94 }),
95 });
96
97 Self { http_client, state }
98 }
99}
100
101impl LanguageModelProviderState for AnthropicLanguageModelProvider {
102 type ObservableEntity = State;
103
104 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
105 Some(self.state.clone())
106 }
107}
108
109impl LanguageModelProvider for AnthropicLanguageModelProvider {
110 fn id(&self) -> LanguageModelProviderId {
111 LanguageModelProviderId(PROVIDER_ID.into())
112 }
113
114 fn name(&self) -> LanguageModelProviderName {
115 LanguageModelProviderName(PROVIDER_NAME.into())
116 }
117
118 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
119 let mut models = BTreeMap::default();
120
121 // Add base models from anthropic::Model::iter()
122 for model in anthropic::Model::iter() {
123 if !matches!(model, anthropic::Model::Custom { .. }) {
124 models.insert(model.id().to_string(), model);
125 }
126 }
127
128 // Override with available models from settings
129 for model in AllLanguageModelSettings::get_global(cx)
130 .anthropic
131 .available_models
132 .iter()
133 {
134 models.insert(
135 model.name.clone(),
136 anthropic::Model::Custom {
137 name: model.name.clone(),
138 max_tokens: model.max_tokens,
139 tool_override: model.tool_override.clone(),
140 },
141 );
142 }
143
144 models
145 .into_values()
146 .map(|model| {
147 Arc::new(AnthropicModel {
148 id: LanguageModelId::from(model.id().to_string()),
149 model,
150 state: self.state.clone(),
151 http_client: self.http_client.clone(),
152 request_limiter: RateLimiter::new(4),
153 }) as Arc<dyn LanguageModel>
154 })
155 .collect()
156 }
157
158 fn is_authenticated(&self, cx: &AppContext) -> bool {
159 self.state.read(cx).is_authenticated()
160 }
161
162 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
163 if self.is_authenticated(cx) {
164 Task::ready(Ok(()))
165 } else {
166 let api_url = AllLanguageModelSettings::get_global(cx)
167 .anthropic
168 .api_url
169 .clone();
170 let state = self.state.clone();
171 cx.spawn(|mut cx| async move {
172 let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
173 api_key
174 } else {
175 let (_, api_key) = cx
176 .update(|cx| cx.read_credentials(&api_url))?
177 .await?
178 .ok_or_else(|| anyhow!("credentials not found"))?;
179 String::from_utf8(api_key)?
180 };
181
182 state.update(&mut cx, |this, cx| {
183 this.api_key = Some(api_key);
184 cx.notify();
185 })
186 })
187 }
188 }
189
190 fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
191 let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx));
192 let focus_handle = view.focus_handle(cx);
193 (view.into(), Some(focus_handle))
194 }
195
196 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
197 self.state.update(cx, |state, cx| state.reset_api_key(cx))
198 }
199}
200
201pub struct AnthropicModel {
202 id: LanguageModelId,
203 model: anthropic::Model,
204 state: gpui::Model<State>,
205 http_client: Arc<dyn HttpClient>,
206 request_limiter: RateLimiter,
207}
208
209pub fn count_anthropic_tokens(
210 request: LanguageModelRequest,
211 cx: &AppContext,
212) -> BoxFuture<'static, Result<usize>> {
213 cx.background_executor()
214 .spawn(async move {
215 let messages = request
216 .messages
217 .into_iter()
218 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
219 role: match message.role {
220 Role::User => "user".into(),
221 Role::Assistant => "assistant".into(),
222 Role::System => "system".into(),
223 },
224 content: Some(message.content),
225 name: None,
226 function_call: None,
227 })
228 .collect::<Vec<_>>();
229
230 // Tiktoken doesn't yet support these models, so we manually use the
231 // same tokenizer as GPT-4.
232 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
233 })
234 .boxed()
235}
236
237impl AnthropicModel {
238 fn request_completion(
239 &self,
240 request: anthropic::Request,
241 cx: &AsyncAppContext,
242 ) -> BoxFuture<'static, Result<anthropic::Response>> {
243 let http_client = self.http_client.clone();
244
245 let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
246 let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
247 (state.api_key.clone(), settings.api_url.clone())
248 }) else {
249 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
250 };
251
252 async move {
253 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
254 anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
255 }
256 .boxed()
257 }
258
259 fn stream_completion(
260 &self,
261 request: anthropic::Request,
262 cx: &AsyncAppContext,
263 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
264 let http_client = self.http_client.clone();
265
266 let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
267 let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
268 (
269 state.api_key.clone(),
270 settings.api_url.clone(),
271 settings.low_speed_timeout,
272 )
273 }) else {
274 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
275 };
276
277 async move {
278 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
279 let request = anthropic::stream_completion(
280 http_client.as_ref(),
281 &api_url,
282 &api_key,
283 request,
284 low_speed_timeout,
285 );
286 request.await
287 }
288 .boxed()
289 }
290}
291
292impl LanguageModel for AnthropicModel {
293 fn id(&self) -> LanguageModelId {
294 self.id.clone()
295 }
296
297 fn name(&self) -> LanguageModelName {
298 LanguageModelName::from(self.model.display_name().to_string())
299 }
300
301 fn provider_id(&self) -> LanguageModelProviderId {
302 LanguageModelProviderId(PROVIDER_ID.into())
303 }
304
305 fn provider_name(&self) -> LanguageModelProviderName {
306 LanguageModelProviderName(PROVIDER_NAME.into())
307 }
308
309 fn telemetry_id(&self) -> String {
310 format!("anthropic/{}", self.model.id())
311 }
312
313 fn max_token_count(&self) -> usize {
314 self.model.max_token_count()
315 }
316
317 fn count_tokens(
318 &self,
319 request: LanguageModelRequest,
320 cx: &AppContext,
321 ) -> BoxFuture<'static, Result<usize>> {
322 count_anthropic_tokens(request, cx)
323 }
324
325 fn stream_completion(
326 &self,
327 request: LanguageModelRequest,
328 cx: &AsyncAppContext,
329 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
330 let request = request.into_anthropic(self.model.id().into());
331 let request = self.stream_completion(request, cx);
332 let future = self.request_limiter.stream(async move {
333 let response = request.await?;
334 Ok(anthropic::extract_text_from_events(response))
335 });
336 async move { Ok(future.await?.boxed()) }.boxed()
337 }
338
339 fn use_any_tool(
340 &self,
341 request: LanguageModelRequest,
342 tool_name: String,
343 tool_description: String,
344 input_schema: serde_json::Value,
345 cx: &AsyncAppContext,
346 ) -> BoxFuture<'static, Result<serde_json::Value>> {
347 let mut request = request.into_anthropic(self.model.tool_model_id().into());
348 request.tool_choice = Some(anthropic::ToolChoice::Tool {
349 name: tool_name.clone(),
350 });
351 request.tools = vec![anthropic::Tool {
352 name: tool_name.clone(),
353 description: tool_description,
354 input_schema,
355 }];
356
357 let response = self.request_completion(request, cx);
358 self.request_limiter
359 .run(async move {
360 let response = response.await?;
361 response
362 .content
363 .into_iter()
364 .find_map(|content| {
365 if let anthropic::Content::ToolUse { name, input, .. } = content {
366 if name == tool_name {
367 Some(input)
368 } else {
369 None
370 }
371 } else {
372 None
373 }
374 })
375 .context("tool not used")
376 })
377 .boxed()
378 }
379}
380
381struct ConfigurationView {
382 api_key_editor: View<Editor>,
383 state: gpui::Model<State>,
384}
385
386impl FocusableView for ConfigurationView {
387 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
388 self.api_key_editor.read(cx).focus_handle(cx)
389 }
390}
391
392impl ConfigurationView {
393 fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
394 Self {
395 api_key_editor: cx.new_view(|cx| {
396 let mut editor = Editor::single_line(cx);
397 editor.set_placeholder_text(
398 "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
399 cx,
400 );
401 editor
402 }),
403 state,
404 }
405 }
406
407 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
408 let api_key = self.api_key_editor.read(cx).text(cx);
409 if api_key.is_empty() {
410 return;
411 }
412
413 self.state
414 .update(cx, |state, cx| state.set_api_key(api_key, cx))
415 .detach_and_log_err(cx);
416 }
417
418 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
419 self.api_key_editor
420 .update(cx, |editor, cx| editor.set_text("", cx));
421 self.state
422 .update(cx, |state, cx| state.reset_api_key(cx))
423 .detach_and_log_err(cx);
424 }
425
426 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
427 let settings = ThemeSettings::get_global(cx);
428 let text_style = TextStyle {
429 color: cx.theme().colors().text,
430 font_family: settings.ui_font.family.clone(),
431 font_features: settings.ui_font.features.clone(),
432 font_fallbacks: settings.ui_font.fallbacks.clone(),
433 font_size: rems(0.875).into(),
434 font_weight: settings.ui_font.weight,
435 font_style: FontStyle::Normal,
436 line_height: relative(1.3),
437 background_color: None,
438 underline: None,
439 strikethrough: None,
440 white_space: WhiteSpace::Normal,
441 };
442 EditorElement::new(
443 &self.api_key_editor,
444 EditorStyle {
445 background: cx.theme().colors().editor_background,
446 local_player: cx.theme().players().local(),
447 text: text_style,
448 ..Default::default()
449 },
450 )
451 }
452}
453
454impl Render for ConfigurationView {
455 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
456 const INSTRUCTIONS: [&str; 4] = [
457 "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
458 "You can create an API key at: https://console.anthropic.com/settings/keys",
459 "",
460 "Paste your Anthropic API key below and hit enter to use the assistant:",
461 ];
462
463 if self.state.read(cx).is_authenticated() {
464 h_flex()
465 .size_full()
466 .justify_between()
467 .child(
468 h_flex()
469 .gap_2()
470 .child(Indicator::dot().color(Color::Success))
471 .child(Label::new("API Key configured").size(LabelSize::Small)),
472 )
473 .child(
474 Button::new("reset-key", "Reset key")
475 .icon(Some(IconName::Trash))
476 .icon_size(IconSize::Small)
477 .icon_position(IconPosition::Start)
478 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
479 )
480 .into_any()
481 } else {
482 v_flex()
483 .size_full()
484 .on_action(cx.listener(Self::save_api_key))
485 .children(
486 INSTRUCTIONS.map(|instruction| Label::new(instruction)),
487 )
488 .child(
489 h_flex()
490 .w_full()
491 .my_2()
492 .px_2()
493 .py_1()
494 .bg(cx.theme().colors().editor_background)
495 .rounded_md()
496 .child(self.render_api_key_editor(cx)),
497 )
498 .child(
499 Label::new(
500 "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
501 )
502 .size(LabelSize::Small),
503 )
504 .into_any()
505 }
506 }
507}