use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use futures::Stream;
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{
    ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
};
use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};

use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;

const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

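// A minimal sketch of how these settings typically appear in Zed's
// settings.json, under `language_models.openai`. The key names for custom
// models follow `AvailableModel`; the values below are illustrative
// assumptions, not defaults:
//
// "language_models": {
//     "openai": {
//         "api_url": "https://api.openai.com/v1",
//         "available_models": [
//             { "name": "my-custom-model", "max_tokens": 128000 }
//         ]
//     }
// }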
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}

pub struct State {
    api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

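    // Loads the API key for the current API URL if it isn't already present,
    // checking the OPENAI_API_KEY environment variable before falling back to
    // the credential store (precedence assumed from `load_if_needed`).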
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }
}

impl OpenAiLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
            }
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(OpenAiLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &OpenAiSettings {
        &crate::AllLanguageModelSettings::get_global(cx).openai
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            open_ai::OPEN_AI_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAi
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from open_ai::Model::iter()
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    reasoning_effort: model.reasoning_effort.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

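        // `RateLimiter::new(4)` above caps this provider at four concurrent
        // requests; the future returned by `stream` resolves once a slot is
        // available and the HTTP response stream has been established.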
        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        use open_ai::Model;
        match &self.model {
            Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano
            | Model::FivePointOne
            | Model::FivePointTwo
            | Model::O1
            | Model::O3
            | Model::O4Mini => true,
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::O3Mini
            | Model::Custom { .. } => false,
        }
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_open_ai_tokens(request, self.model.clone(), cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.model.supports_prompt_cache_key(),
            self.max_output_tokens(),
            self.model.reasoning_effort(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
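    // The o1 family (o1-mini, o1-preview) did not support streaming responses,
    // so request a non-streaming completion for those model ids.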
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    if !text.trim().is_empty() {
                        add_message_content_part(
                            open_ai::MessagePart::Text { text },
                            message.role,
                            &mut messages,
                        );
                    }
                }
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
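        // Fall back to OpenAI's default temperature of 1.0 when the request
        // doesn't specify one.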
        temperature: request.temperature.or(Some(1.0)),
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
            None
        },
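        // Reuse the thread id as OpenAI's `prompt_cache_key` so consecutive
        // requests from the same thread are more likely to hit the prompt cache.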
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}

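// Appends `new_part` to the last message if it has the same role, otherwise
// starts a new message. For example, a user text part followed by a user image
// part ends up as a single `User` message with two parts, rather than two
// consecutive messages.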
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(open_ai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

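// OpenAI streams tool calls as deltas: each chunk carries an index along with
// fragments of the call's id, name, and JSON arguments. The mapper accumulates
// those fragments per index and emits complete `ToolUse` events only when the
// stream reports a `tool_calls` finish reason.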
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(delta) = choice.delta.as_ref() {
            if let Some(content) = delta.content.clone() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }

            if let Some(tool_calls) = delta.tool_calls.as_ref() {
                for tool_call in tool_calls {
                    let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                    if let Some(tool_id) = tool_call.id.clone() {
                        entry.id = tool_id;
                    }

                    if let Some(function) = tool_call.function.as_ref() {
                        if let Some(name) = function.name.clone() {
                            entry.name = name;
                        }

                        if let Some(arguments) = function.arguments.clone() {
                            entry.arguments.push_str(&arguments);
                        }
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

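// Flattens request messages into tiktoken-rs chat messages. Only
// `string_contents()` is counted, so token estimates for requests containing
// images or tool calls are approximate.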
pub(crate) fn collect_tiktoken_messages(
    request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
    request
        .messages
        .into_iter()
        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
            role: match message.role {
                Role::User => "user".into(),
                Role::Assistant => "assistant".into(),
                Role::System => "system".into(),
            },
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        })
        .collect::<Vec<_>>()
}

pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);
        match model {
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    // A max token count of 100k or more suggests the o200k_base
                    // tokenizer used by gpt-4o.
                    "gpt-4o"
                } else {
                    // Otherwise fall back to gpt-4, since only cl100k_base and
                    // o200k_base are supported by this tiktoken method.
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Models currently supported by tiktoken-rs.
            // tiktoken-rs sometimes lags behind on model support. If that happens, add a new
            // match arm with an override. We enumerate all supported models here so we can
            // check whether new models are supported yet.
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // GPT-5.1 and 5.2 don't have dedicated tiktoken support; use the gpt-5 tokenizer.
            Model::FivePointOne | Model::FivePointTwo => {
                tiktoken_rs::num_tokens_from_messages("gpt-5", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

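// Settings UI for this provider. Renders one of three states: a loading label
// while stored credentials are being fetched, an API-key entry form when
// unauthenticated, and a configured-key card (read-only when the key comes
// from the environment variable) once a key is present.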
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            InputField::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Don't log the result: "not signed in" also surfaces as an error here.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // Clear the input, since URL changes can cause the editor to be displayed again.
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |input, cx| input.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            if api_url == OPEN_AI_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!(
                        "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
                    ))
                })
                .into_any_element()
        };

        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}

#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs
        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
}