1use crate::{
2 settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
3 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
4 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
5};
6use anthropic::AnthropicError;
7use anyhow::{anyhow, Context as _, Result};
8use collections::BTreeMap;
9use editor::{Editor, EditorElement, EditorStyle};
10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
11use gpui::{
12 AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
13 View, WhiteSpace,
14};
15use http_client::HttpClient;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::{sync::Arc, time::Duration};
20use strum::IntoEnumIterator;
21use theme::ThemeSettings;
22use ui::{prelude::*, Icon, IconName};
23use util::ResultExt;
24
/// Identifier wrapped into this provider's (and its models') `LanguageModelProviderId`.
const PROVIDER_ID: &str = "anthropic";
/// Display name wrapped into this provider's (and its models') `LanguageModelProviderName`.
const PROVIDER_NAME: &str = "Anthropic";
27
/// Anthropic provider settings, read in this file through
/// `AllLanguageModelSettings::get_global(cx).anthropic`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    /// URL passed to `anthropic::stream_completion`; it is also the key under which the
    /// API key credential is written, read, and deleted.
    pub api_url: String,
    /// Forwarded unchanged to `anthropic::stream_completion` as its low-speed timeout.
    pub low_speed_timeout: Option<Duration>,
    /// Extra models declared in settings; an entry whose name matches a built-in model id
    /// replaces that built-in model in `provided_models`.
    pub available_models: Vec<AvailableModel>,
    // NOTE(review): not read anywhere in this file; presumably flags settings that still
    // need migrating to a newer format — confirm against the settings loader.
    pub needs_setting_migration: bool,
}
35
// A model declared in the user's settings. `provided_models` converts each entry into an
// `anthropic::Model::Custom` with the same field values.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Used both as the custom model's name and as its map key, so it overrides a
    // built-in model whose id is identical.
    pub name: String,
    // Copied into `Custom::max_tokens`; presumably the context-window size reported by
    // `max_token_count` — confirm in the `anthropic` crate.
    pub max_tokens: usize,
    // Copied into `Custom::tool_override`; presumably an alternate model id used for tool
    // calls (see `tool_model_id`) — confirm.
    pub tool_override: Option<String>,
    // Optional prompt-caching settings, converted field-by-field into
    // `anthropic::AnthropicModelCacheConfiguration`.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    // Copied into `Custom::max_output_tokens`.
    pub max_output_tokens: Option<u32>,
}
44
/// Language model provider that exposes Anthropic models.
pub struct AnthropicLanguageModelProvider {
    /// HTTP client cloned into every `AnthropicModel` this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state, also handed to each model and configuration view.
    state: gpui::Model<State>,
}
49
/// Authentication state shared by the provider, its models, and its configuration view.
pub struct State {
    /// The API key currently in use, or `None` when not authenticated.
    api_key: Option<String>,
    /// Settings-store observer registered in `AnthropicLanguageModelProvider::new`; held here
    /// presumably so the observation stays active for this state's lifetime.
    _subscription: Subscription,
}
54
55impl State {
56 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
57 let delete_credentials =
58 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
59 cx.spawn(|this, mut cx| async move {
60 delete_credentials.await.ok();
61 this.update(&mut cx, |this, cx| {
62 this.api_key = None;
63 cx.notify();
64 })
65 })
66 }
67
68 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
69 let write_credentials = cx.write_credentials(
70 AllLanguageModelSettings::get_global(cx)
71 .anthropic
72 .api_url
73 .as_str(),
74 "Bearer",
75 api_key.as_bytes(),
76 );
77 cx.spawn(|this, mut cx| async move {
78 write_credentials.await?;
79
80 this.update(&mut cx, |this, cx| {
81 this.api_key = Some(api_key);
82 cx.notify();
83 })
84 })
85 }
86
87 fn is_authenticated(&self) -> bool {
88 self.api_key.is_some()
89 }
90
91 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
92 if self.is_authenticated() {
93 Task::ready(Ok(()))
94 } else {
95 let api_url = AllLanguageModelSettings::get_global(cx)
96 .anthropic
97 .api_url
98 .clone();
99
100 cx.spawn(|this, mut cx| async move {
101 let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
102 api_key
103 } else {
104 let (_, api_key) = cx
105 .update(|cx| cx.read_credentials(&api_url))?
106 .await?
107 .ok_or_else(|| anyhow!("credentials not found"))?;
108 String::from_utf8(api_key)?
109 };
110
111 this.update(&mut cx, |this, cx| {
112 this.api_key = Some(api_key);
113 cx.notify();
114 })
115 })
116 }
117 }
118}
119
120impl AnthropicLanguageModelProvider {
121 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
122 let state = cx.new_model(|cx| State {
123 api_key: None,
124 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
125 cx.notify();
126 }),
127 });
128
129 Self { http_client, state }
130 }
131}
132
133impl LanguageModelProviderState for AnthropicLanguageModelProvider {
134 type ObservableEntity = State;
135
136 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
137 Some(self.state.clone())
138 }
139}
140
141impl LanguageModelProvider for AnthropicLanguageModelProvider {
142 fn id(&self) -> LanguageModelProviderId {
143 LanguageModelProviderId(PROVIDER_ID.into())
144 }
145
146 fn name(&self) -> LanguageModelProviderName {
147 LanguageModelProviderName(PROVIDER_NAME.into())
148 }
149
150 fn icon(&self) -> IconName {
151 IconName::AiAnthropic
152 }
153
154 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
155 let mut models = BTreeMap::default();
156
157 // Add base models from anthropic::Model::iter()
158 for model in anthropic::Model::iter() {
159 if !matches!(model, anthropic::Model::Custom { .. }) {
160 models.insert(model.id().to_string(), model);
161 }
162 }
163
164 // Override with available models from settings
165 for model in AllLanguageModelSettings::get_global(cx)
166 .anthropic
167 .available_models
168 .iter()
169 {
170 models.insert(
171 model.name.clone(),
172 anthropic::Model::Custom {
173 name: model.name.clone(),
174 max_tokens: model.max_tokens,
175 tool_override: model.tool_override.clone(),
176 cache_configuration: model.cache_configuration.as_ref().map(|config| {
177 anthropic::AnthropicModelCacheConfiguration {
178 max_cache_anchors: config.max_cache_anchors,
179 should_speculate: config.should_speculate,
180 min_total_token: config.min_total_token,
181 }
182 }),
183 max_output_tokens: model.max_output_tokens,
184 },
185 );
186 }
187
188 models
189 .into_values()
190 .map(|model| {
191 Arc::new(AnthropicModel {
192 id: LanguageModelId::from(model.id().to_string()),
193 model,
194 state: self.state.clone(),
195 http_client: self.http_client.clone(),
196 request_limiter: RateLimiter::new(4),
197 }) as Arc<dyn LanguageModel>
198 })
199 .collect()
200 }
201
202 fn is_authenticated(&self, cx: &AppContext) -> bool {
203 self.state.read(cx).is_authenticated()
204 }
205
206 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
207 self.state.update(cx, |state, cx| state.authenticate(cx))
208 }
209
210 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
211 cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
212 .into()
213 }
214
215 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
216 self.state.update(cx, |state, cx| state.reset_api_key(cx))
217 }
218}
219
/// A single Anthropic model, built-in or settings-defined, exposed as a `LanguageModel`.
pub struct AnthropicModel {
    /// Identifier built from the model's id string.
    id: LanguageModelId,
    /// The underlying `anthropic` model description.
    model: anthropic::Model,
    /// Shared provider state; its API key is read at the start of each request.
    state: gpui::Model<State>,
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Wraps every request; presumably caps concurrent requests at the limit the provider
    /// passes to `RateLimiter::new` (4) — confirm in `RateLimiter`.
    request_limiter: RateLimiter,
}
227
228pub fn count_anthropic_tokens(
229 request: LanguageModelRequest,
230 cx: &AppContext,
231) -> BoxFuture<'static, Result<usize>> {
232 cx.background_executor()
233 .spawn(async move {
234 let messages = request.messages;
235 let mut tokens_from_images = 0;
236 let mut string_messages = Vec::with_capacity(messages.len());
237
238 for message in messages {
239 use crate::MessageContent;
240
241 let mut string_contents = String::new();
242
243 for content in message.content {
244 match content {
245 MessageContent::Text(string) => {
246 string_contents.push_str(&string);
247 }
248 MessageContent::Image(image) => {
249 tokens_from_images += image.estimate_tokens();
250 }
251 }
252 }
253
254 if !string_contents.is_empty() {
255 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
256 role: match message.role {
257 Role::User => "user".into(),
258 Role::Assistant => "assistant".into(),
259 Role::System => "system".into(),
260 },
261 content: Some(string_contents),
262 name: None,
263 function_call: None,
264 });
265 }
266 }
267
268 // Tiktoken doesn't yet support these models, so we manually use the
269 // same tokenizer as GPT-4.
270 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
271 .map(|tokens| tokens + tokens_from_images)
272 })
273 .boxed()
274}
275
276impl AnthropicModel {
277 fn stream_completion(
278 &self,
279 request: anthropic::Request,
280 cx: &AsyncAppContext,
281 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
282 {
283 let http_client = self.http_client.clone();
284
285 let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
286 let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
287 (
288 state.api_key.clone(),
289 settings.api_url.clone(),
290 settings.low_speed_timeout,
291 )
292 }) else {
293 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
294 };
295
296 async move {
297 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
298 let request = anthropic::stream_completion(
299 http_client.as_ref(),
300 &api_url,
301 &api_key,
302 request,
303 low_speed_timeout,
304 );
305 request.await.context("failed to stream completion")
306 }
307 .boxed()
308 }
309}
310
311impl LanguageModel for AnthropicModel {
312 fn id(&self) -> LanguageModelId {
313 self.id.clone()
314 }
315
316 fn name(&self) -> LanguageModelName {
317 LanguageModelName::from(self.model.display_name().to_string())
318 }
319
320 fn provider_id(&self) -> LanguageModelProviderId {
321 LanguageModelProviderId(PROVIDER_ID.into())
322 }
323
324 fn provider_name(&self) -> LanguageModelProviderName {
325 LanguageModelProviderName(PROVIDER_NAME.into())
326 }
327
328 fn telemetry_id(&self) -> String {
329 format!("anthropic/{}", self.model.id())
330 }
331
332 fn max_token_count(&self) -> usize {
333 self.model.max_token_count()
334 }
335
336 fn max_output_tokens(&self) -> Option<u32> {
337 Some(self.model.max_output_tokens())
338 }
339
340 fn count_tokens(
341 &self,
342 request: LanguageModelRequest,
343 cx: &AppContext,
344 ) -> BoxFuture<'static, Result<usize>> {
345 count_anthropic_tokens(request, cx)
346 }
347
348 fn stream_completion(
349 &self,
350 request: LanguageModelRequest,
351 cx: &AsyncAppContext,
352 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
353 let request =
354 request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
355 let request = self.stream_completion(request, cx);
356 let future = self.request_limiter.stream(async move {
357 let response = request.await.map_err(|err| anyhow!(err))?;
358 Ok(anthropic::extract_text_from_events(response))
359 });
360 async move {
361 Ok(future
362 .await?
363 .map(|result| result.map_err(|err| anyhow!(err)))
364 .boxed())
365 }
366 .boxed()
367 }
368
369 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
370 self.model
371 .cache_configuration()
372 .map(|config| LanguageModelCacheConfiguration {
373 max_cache_anchors: config.max_cache_anchors,
374 should_speculate: config.should_speculate,
375 min_total_token: config.min_total_token,
376 })
377 }
378
379 fn use_any_tool(
380 &self,
381 request: LanguageModelRequest,
382 tool_name: String,
383 tool_description: String,
384 input_schema: serde_json::Value,
385 cx: &AsyncAppContext,
386 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
387 let mut request = request.into_anthropic(
388 self.model.tool_model_id().into(),
389 self.model.max_output_tokens(),
390 );
391 request.tool_choice = Some(anthropic::ToolChoice::Tool {
392 name: tool_name.clone(),
393 });
394 request.tools = vec![anthropic::Tool {
395 name: tool_name.clone(),
396 description: tool_description,
397 input_schema,
398 }];
399
400 let response = self.stream_completion(request, cx);
401 self.request_limiter
402 .run(async move {
403 let response = response.await?;
404 Ok(anthropic::extract_tool_args_from_events(
405 tool_name,
406 Box::pin(response.map_err(|e| anyhow!(e))),
407 )
408 .await?
409 .boxed())
410 })
411 .boxed()
412 }
413}
414
/// Settings UI for entering, showing, and resetting the Anthropic API key.
struct ConfigurationView {
    /// Single-line editor the user types or pastes the API key into.
    api_key_editor: View<Editor>,
    /// Shared authentication state; this view re-renders whenever it changes.
    state: gpui::Model<State>,
    /// Initial credential-loading task: `Some` while loading, set to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
420
421impl ConfigurationView {
422 const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
423
424 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
425 cx.observe(&state, |_, _, cx| {
426 cx.notify();
427 })
428 .detach();
429
430 let load_credentials_task = Some(cx.spawn({
431 let state = state.clone();
432 |this, mut cx| async move {
433 if let Some(task) = state
434 .update(&mut cx, |state, cx| state.authenticate(cx))
435 .log_err()
436 {
437 // We don't log an error, because "not signed in" is also an error.
438 let _ = task.await;
439 }
440 this.update(&mut cx, |this, cx| {
441 this.load_credentials_task = None;
442 cx.notify();
443 })
444 .log_err();
445 }
446 }));
447
448 Self {
449 api_key_editor: cx.new_view(|cx| {
450 let mut editor = Editor::single_line(cx);
451 editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
452 editor
453 }),
454 state,
455 load_credentials_task,
456 }
457 }
458
459 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
460 let api_key = self.api_key_editor.read(cx).text(cx);
461 if api_key.is_empty() {
462 return;
463 }
464
465 let state = self.state.clone();
466 cx.spawn(|_, mut cx| async move {
467 state
468 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
469 .await
470 })
471 .detach_and_log_err(cx);
472
473 cx.notify();
474 }
475
476 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
477 self.api_key_editor
478 .update(cx, |editor, cx| editor.set_text("", cx));
479
480 let state = self.state.clone();
481 cx.spawn(|_, mut cx| async move {
482 state
483 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
484 .await
485 })
486 .detach_and_log_err(cx);
487
488 cx.notify();
489 }
490
491 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
492 let settings = ThemeSettings::get_global(cx);
493 let text_style = TextStyle {
494 color: cx.theme().colors().text,
495 font_family: settings.ui_font.family.clone(),
496 font_features: settings.ui_font.features.clone(),
497 font_fallbacks: settings.ui_font.fallbacks.clone(),
498 font_size: rems(0.875).into(),
499 font_weight: settings.ui_font.weight,
500 font_style: FontStyle::Normal,
501 line_height: relative(1.3),
502 background_color: None,
503 underline: None,
504 strikethrough: None,
505 white_space: WhiteSpace::Normal,
506 };
507 EditorElement::new(
508 &self.api_key_editor,
509 EditorStyle {
510 background: cx.theme().colors().editor_background,
511 local_player: cx.theme().players().local(),
512 text: text_style,
513 ..Default::default()
514 },
515 )
516 }
517
518 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
519 !self.state.read(cx).is_authenticated()
520 }
521}
522
523impl Render for ConfigurationView {
524 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
525 const INSTRUCTIONS: [&str; 4] = [
526 "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
527 "You can create an API key at: https://console.anthropic.com/settings/keys",
528 "",
529 "Paste your Anthropic API key below and hit enter to use the assistant:",
530 ];
531
532 if self.load_credentials_task.is_some() {
533 div().child(Label::new("Loading credentials...")).into_any()
534 } else if self.should_render_editor(cx) {
535 v_flex()
536 .size_full()
537 .on_action(cx.listener(Self::save_api_key))
538 .children(
539 INSTRUCTIONS.map(|instruction| Label::new(instruction)),
540 )
541 .child(
542 h_flex()
543 .w_full()
544 .my_2()
545 .px_2()
546 .py_1()
547 .bg(cx.theme().colors().editor_background)
548 .rounded_md()
549 .child(self.render_api_key_editor(cx)),
550 )
551 .child(
552 Label::new(
553 "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
554 )
555 .size(LabelSize::Small),
556 )
557 .into_any()
558 } else {
559 h_flex()
560 .size_full()
561 .justify_between()
562 .child(
563 h_flex()
564 .gap_1()
565 .child(Icon::new(IconName::Check).color(Color::Success))
566 .child(Label::new("API key configured.")),
567 )
568 .child(
569 Button::new("reset-key", "Reset key")
570 .icon(Some(IconName::Trash))
571 .icon_size(IconSize::Small)
572 .icon_position(IconPosition::Start)
573 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
574 )
575 .into_any()
576 }
577 }
578}