use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use futures::Stream;
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{
    ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};

use crate::{api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;

const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

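/// Settings for the OpenAI provider, read from Zed's settings.
///
/// An empty `api_url` falls back to the default OpenAI endpoint, and
/// `available_models` registers additional models beyond the built-in ones
/// (see [`AvailableModel`]).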
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    pub reasoning_effort: Option<ReasoningEffort>,
}

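/// A [`LanguageModelProvider`] serving OpenAI (and OpenAI-compatible) models.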
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

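/// Holds the API key state for the currently configured API URL.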
pub struct State {
    api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }
}

impl OpenAiLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
            }
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(OpenAiLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &OpenAiSettings {
        &crate::AllLanguageModelSettings::get_global(cx).openai
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            open_ai::OPEN_AI_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAi
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from open_ai::Model::iter().
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings.
        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    reasoning_effort: model.reasoning_effort.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

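/// A single OpenAI model, holding the shared provider state and a rate limiter
/// for completion requests.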
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        use open_ai::Model;
        match &self.model {
            Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano
            | Model::O1
            | Model::O3
            | Model::O4Mini => true,
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::O3Mini
            | Model::Custom { .. } => false,
        }
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_open_ai_tokens(request, self.model.clone(), cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.model.supports_prompt_cache_key(),
            self.max_output_tokens(),
            self.model.reasoning_effort(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

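/// Converts a [`LanguageModelRequest`] into an [`open_ai::Request`].
///
/// Consecutive content parts from the same role are merged into a single request
/// message, tool calls are attached to the preceding assistant message, and
/// streaming is disabled for "o1-"-prefixed model ids. A usage sketch lives in
/// the test module below (`into_open_ai_request_shape`).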
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    add_message_content_part(
                        open_ai::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    )
                }
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
            None
        },
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}

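/// Appends `new_part` to the last request message when it has a matching role;
/// otherwise starts a new message for `role`.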
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(open_ai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

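/// Maps OpenAI response stream events to [`LanguageModelCompletionEvent`]s,
/// accumulating partial tool calls by index until a finish reason arrives.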
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(content) = choice.delta.content.clone() {
            if !content.is_empty() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

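/// A tool call assembled incrementally from streamed deltas.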
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

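/// Flattens request messages into the shape tiktoken-rs expects for token counting.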
pub(crate) fn collect_tiktoken_messages(
    request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
    request
        .messages
        .into_iter()
        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
            role: match message.role {
                Role::User => "user".into(),
                Role::Assistant => "assistant".into(),
                Role::System => "system".into(),
            },
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        })
        .collect::<Vec<_>>()
}

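/// Counts request tokens on a background thread using tiktoken-rs, falling back
/// to the gpt-4o or gpt-4 tokenizer for models tiktoken-rs doesn't know about.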
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);

        match model {
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    // If the max token count is 100k or more, the model likely uses the
                    // o200k_base tokenizer from gpt-4o.
                    "gpt-4o"
                } else {
                    // Otherwise fall back to gpt-4, since only cl100k_base and o200k_base are
                    // supported with this tiktoken method.
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Currently supported by tiktoken_rs.
            // Sometimes tiktoken-rs is behind on model support. If that is the case, add a new
            // match arm with an override. We enumerate all supported models here so that we can
            // check whether new models are supported yet or not.
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // GPT-5 models don't have tiktoken support yet; fall back on the gpt-4o tokenizer.
            Model::Five | Model::FiveMini | Model::FiveNano => {
                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

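/// The provider's configuration UI: an API key input when unauthenticated, and a
/// summary row with a reset button once a key is set.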
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error here, because "not signed in" also surfaces as an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // Clear the editor eagerly, since URL changes can cause it to be displayed again.
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |input, cx| input.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
                        } else {
                            let api_url = OpenAiLanguageModelProvider::api_url(cx);
                            if api_url == OPEN_AI_API_URL {
                                "API key configured".to_string()
                            } else {
                                format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
                            }
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}

#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs.
        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
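
    // A minimal sketch (not part of the upstream suite) exercising `into_open_ai`:
    // streaming is disabled for "o1-"-prefixed model ids and enabled otherwise, and
    // the max output token limit is passed through as `max_completion_tokens`. The
    // model ids below are illustrative inputs, not an exhaustive list.
    #[test]
    fn into_open_ai_request_shape() {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        let o1_request = into_open_ai(request.clone(), "o1-mini", false, false, None, None);
        assert!(!o1_request.stream);

        let gpt_request = into_open_ai(request, "gpt-4o", true, true, Some(1024), None);
        assert!(gpt_request.stream);
        assert_eq!(gpt_request.max_completion_tokens, Some(1024));
        // With no tools in the request, `parallel_tool_calls` is left unset.
        assert_eq!(gpt_request.parallel_tool_calls, None);
    }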
}