1use anyhow::{Result, anyhow};
2use collections::{BTreeMap, HashMap};
3use futures::Stream;
4use futures::{FutureExt, StreamExt, future, future::BoxFuture};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
6use http_client::HttpClient;
7use language_model::{
8 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
9 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
10 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
11 LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
12 RateLimiter, Role, StopReason, TokenUsage,
13};
14use menu;
15use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::pin::Pin;
20use std::str::FromStr as _;
21use std::sync::{Arc, LazyLock};
22use strum::IntoEnumIterator;
23use ui::{ElevationIndex, List, Tooltip, prelude::*};
24use ui_input::SingleLineInput;
25use util::ResultExt;
26use zed_env_vars::{EnvVar, env_var};
27
28use crate::{api_key::ApiKeyState, ui::InstructionListItem};
29
/// Provider id reported by the provider (`id()`) and by each of its models (`provider_id()`).
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
/// Provider name reported by `name()` / `provider_name()` and attached to `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;

/// Name of the environment variable that can supply the API key; also shown in the configuration UI.
const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
/// Lazily-initialized handle for `API_KEY_ENV_VAR_NAME`, passed to `ApiKeyState` when loading keys.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
35
/// OpenAI provider settings, read from the global `AllLanguageModelSettings` (`openai` field).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL for API requests; an empty string falls back to `open_ai::OPEN_AI_API_URL`.
    pub api_url: String,
    /// Models declared in settings. They are merged after the built-in models, so an entry whose
    /// `name` equals a built-in model id replaces that built-in model.
    pub available_models: Vec<AvailableModel>,
}
41
// A model declared in the user's settings and turned into `open_ai::Model::Custom`.
// NOTE: plain `//` comments are used deliberately. This type derives `JsonSchema`, and `///` doc
// comments would be emitted into the settings schema as descriptions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model id sent to the API; also the key used when merging with built-in models.
    pub name: String,
    // Optional human-readable name, forwarded to `Model::Custom { display_name }`.
    pub display_name: Option<String>,
    // Context size. `count_open_ai_tokens` also uses it to pick a tokenizer
    // (>= 100_000 -> gpt-4o tokenizer, otherwise gpt-4).
    pub max_tokens: u64,
    // Forwarded to `Model::Custom`; presumably the output-token cap reported by
    // `Model::max_output_tokens()`. TODO confirm against the `open_ai` crate.
    pub max_output_tokens: Option<u64>,
    // Forwarded to `Model::Custom` unchanged.
    pub max_completion_tokens: Option<u64>,
    // Forwarded to `Model::Custom`; presumably what `Model::reasoning_effort()` returns.
    pub reasoning_effort: Option<ReasoningEffort>,
}
51
/// Language model provider for OpenAI models.
pub struct OpenAiLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state; also observed by the configuration view.
    state: gpui::Entity<State>,
}
56
/// Observable authentication state shared by the provider, its models and its configuration view.
pub struct State {
    /// API key bookkeeping. Reads and writes always pass the currently configured API URL.
    api_key_state: ApiKeyState,
}
60
61impl State {
62 fn is_authenticated(&self) -> bool {
63 self.api_key_state.has_key()
64 }
65
66 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
67 let api_url = OpenAiLanguageModelProvider::api_url(cx);
68 self.api_key_state
69 .store(api_url, api_key, |this| &mut this.api_key_state, cx)
70 }
71
72 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
73 let api_url = OpenAiLanguageModelProvider::api_url(cx);
74 self.api_key_state.load_if_needed(
75 api_url,
76 &API_KEY_ENV_VAR,
77 |this| &mut this.api_key_state,
78 cx,
79 )
80 }
81}
82
83impl OpenAiLanguageModelProvider {
84 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
85 let state = cx.new(|cx| {
86 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
87 let api_url = Self::api_url(cx);
88 this.api_key_state.handle_url_change(
89 api_url,
90 &API_KEY_ENV_VAR,
91 |this| &mut this.api_key_state,
92 cx,
93 );
94 cx.notify();
95 })
96 .detach();
97 State {
98 api_key_state: ApiKeyState::new(Self::api_url(cx)),
99 }
100 });
101
102 Self { http_client, state }
103 }
104
105 fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
106 Arc::new(OpenAiLanguageModel {
107 id: LanguageModelId::from(model.id().to_string()),
108 model,
109 state: self.state.clone(),
110 http_client: self.http_client.clone(),
111 request_limiter: RateLimiter::new(4),
112 })
113 }
114
115 fn settings(cx: &App) -> &OpenAiSettings {
116 &crate::AllLanguageModelSettings::get_global(cx).openai
117 }
118
119 fn api_url(cx: &App) -> SharedString {
120 let api_url = &Self::settings(cx).api_url;
121 if api_url.is_empty() {
122 open_ai::OPEN_AI_API_URL.into()
123 } else {
124 SharedString::new(api_url.as_str())
125 }
126 }
127}
128
129impl LanguageModelProviderState for OpenAiLanguageModelProvider {
130 type ObservableEntity = State;
131
132 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
133 Some(self.state.clone())
134 }
135}
136
137impl LanguageModelProvider for OpenAiLanguageModelProvider {
138 fn id(&self) -> LanguageModelProviderId {
139 PROVIDER_ID
140 }
141
142 fn name(&self) -> LanguageModelProviderName {
143 PROVIDER_NAME
144 }
145
146 fn icon(&self) -> IconName {
147 IconName::AiOpenAi
148 }
149
150 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
151 Some(self.create_language_model(open_ai::Model::default()))
152 }
153
154 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
155 Some(self.create_language_model(open_ai::Model::default_fast()))
156 }
157
158 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
159 let mut models = BTreeMap::default();
160
161 // Add base models from open_ai::Model::iter()
162 for model in open_ai::Model::iter() {
163 if !matches!(model, open_ai::Model::Custom { .. }) {
164 models.insert(model.id().to_string(), model);
165 }
166 }
167
168 // Override with available models from settings
169 for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
170 models.insert(
171 model.name.clone(),
172 open_ai::Model::Custom {
173 name: model.name.clone(),
174 display_name: model.display_name.clone(),
175 max_tokens: model.max_tokens,
176 max_output_tokens: model.max_output_tokens,
177 max_completion_tokens: model.max_completion_tokens,
178 reasoning_effort: model.reasoning_effort.clone(),
179 },
180 );
181 }
182
183 models
184 .into_values()
185 .map(|model| self.create_language_model(model))
186 .collect()
187 }
188
189 fn is_authenticated(&self, cx: &App) -> bool {
190 self.state.read(cx).is_authenticated()
191 }
192
193 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
194 self.state.update(cx, |state, cx| state.authenticate(cx))
195 }
196
197 fn configuration_view(
198 &self,
199 _target_agent: language_model::ConfigurationViewTargetAgent,
200 window: &mut Window,
201 cx: &mut App,
202 ) -> AnyView {
203 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
204 .into()
205 }
206
207 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
208 self.state
209 .update(cx, |state, cx| state.set_api_key(None, cx))
210 }
211}
212
/// A single OpenAI model exposed through the [`LanguageModel`] trait.
pub struct OpenAiLanguageModel {
    /// Id built from `model.id()` when the model is created.
    id: LanguageModelId,
    /// Built-in or settings-defined (`Custom`) model description.
    model: open_ai::Model,
    /// Shared provider state; the API key and URL are read from it for every request.
    state: gpui::Entity<State>,
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Request throttle, created with `RateLimiter::new(4)` (presumably a cap on in-flight requests).
    request_limiter: RateLimiter,
}
220
221impl OpenAiLanguageModel {
222 fn stream_completion(
223 &self,
224 request: open_ai::Request,
225 cx: &AsyncApp,
226 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
227 {
228 let http_client = self.http_client.clone();
229
230 let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
231 let api_url = OpenAiLanguageModelProvider::api_url(cx);
232 (state.api_key_state.key(&api_url), api_url)
233 }) else {
234 return future::ready(Err(anyhow!("App state dropped"))).boxed();
235 };
236
237 let future = self.request_limiter.stream(async move {
238 let Some(api_key) = api_key else {
239 return Err(LanguageModelCompletionError::NoApiKey {
240 provider: PROVIDER_NAME,
241 });
242 };
243 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
244 let response = request.await?;
245 Ok(response)
246 });
247
248 async move { Ok(future.await?.boxed()) }.boxed()
249 }
250}
251
252impl LanguageModel for OpenAiLanguageModel {
253 fn id(&self) -> LanguageModelId {
254 self.id.clone()
255 }
256
257 fn name(&self) -> LanguageModelName {
258 LanguageModelName::from(self.model.display_name().to_string())
259 }
260
261 fn provider_id(&self) -> LanguageModelProviderId {
262 PROVIDER_ID
263 }
264
265 fn provider_name(&self) -> LanguageModelProviderName {
266 PROVIDER_NAME
267 }
268
269 fn supports_tools(&self) -> bool {
270 true
271 }
272
273 fn supports_images(&self) -> bool {
274 use open_ai::Model;
275 match &self.model {
276 Model::FourOmni
277 | Model::FourOmniMini
278 | Model::FourPointOne
279 | Model::FourPointOneMini
280 | Model::FourPointOneNano
281 | Model::Five
282 | Model::FiveMini
283 | Model::FiveNano
284 | Model::O1
285 | Model::O3
286 | Model::O4Mini => true,
287 Model::ThreePointFiveTurbo
288 | Model::Four
289 | Model::FourTurbo
290 | Model::O3Mini
291 | Model::Custom { .. } => false,
292 }
293 }
294
295 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
296 match choice {
297 LanguageModelToolChoice::Auto => true,
298 LanguageModelToolChoice::Any => true,
299 LanguageModelToolChoice::None => true,
300 }
301 }
302
303 fn telemetry_id(&self) -> String {
304 format!("openai/{}", self.model.id())
305 }
306
307 fn max_token_count(&self) -> u64 {
308 self.model.max_token_count()
309 }
310
311 fn max_output_tokens(&self) -> Option<u64> {
312 self.model.max_output_tokens()
313 }
314
315 fn count_tokens(
316 &self,
317 request: LanguageModelRequest,
318 cx: &App,
319 ) -> BoxFuture<'static, Result<u64>> {
320 count_open_ai_tokens(request, self.model.clone(), cx)
321 }
322
323 fn stream_completion(
324 &self,
325 request: LanguageModelRequest,
326 cx: &AsyncApp,
327 ) -> BoxFuture<
328 'static,
329 Result<
330 futures::stream::BoxStream<
331 'static,
332 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
333 >,
334 LanguageModelCompletionError,
335 >,
336 > {
337 let request = into_open_ai(
338 request,
339 self.model.id(),
340 self.model.supports_parallel_tool_calls(),
341 self.model.supports_prompt_cache_key(),
342 self.max_output_tokens(),
343 self.model.reasoning_effort(),
344 );
345 let completions = self.stream_completion(request, cx);
346 async move {
347 let mapper = OpenAiEventMapper::new();
348 Ok(mapper.map_stream(completions.await?).boxed())
349 }
350 .boxed()
351 }
352}
353
354pub fn into_open_ai(
355 request: LanguageModelRequest,
356 model_id: &str,
357 supports_parallel_tool_calls: bool,
358 supports_prompt_cache_key: bool,
359 max_output_tokens: Option<u64>,
360 reasoning_effort: Option<ReasoningEffort>,
361) -> open_ai::Request {
362 let stream = !model_id.starts_with("o1-");
363
364 let mut messages = Vec::new();
365 for message in request.messages {
366 for content in message.content {
367 match content {
368 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
369 add_message_content_part(
370 open_ai::MessagePart::Text { text },
371 message.role,
372 &mut messages,
373 )
374 }
375 MessageContent::RedactedThinking(_) => {}
376 MessageContent::Image(image) => {
377 add_message_content_part(
378 open_ai::MessagePart::Image {
379 image_url: ImageUrl {
380 url: image.to_base64_url(),
381 detail: None,
382 },
383 },
384 message.role,
385 &mut messages,
386 );
387 }
388 MessageContent::ToolUse(tool_use) => {
389 let tool_call = open_ai::ToolCall {
390 id: tool_use.id.to_string(),
391 content: open_ai::ToolCallContent::Function {
392 function: open_ai::FunctionContent {
393 name: tool_use.name.to_string(),
394 arguments: serde_json::to_string(&tool_use.input)
395 .unwrap_or_default(),
396 },
397 },
398 };
399
400 if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
401 messages.last_mut()
402 {
403 tool_calls.push(tool_call);
404 } else {
405 messages.push(open_ai::RequestMessage::Assistant {
406 content: None,
407 tool_calls: vec![tool_call],
408 });
409 }
410 }
411 MessageContent::ToolResult(tool_result) => {
412 let content = match &tool_result.content {
413 LanguageModelToolResultContent::Text(text) => {
414 vec![open_ai::MessagePart::Text {
415 text: text.to_string(),
416 }]
417 }
418 LanguageModelToolResultContent::Image(image) => {
419 vec![open_ai::MessagePart::Image {
420 image_url: ImageUrl {
421 url: image.to_base64_url(),
422 detail: None,
423 },
424 }]
425 }
426 };
427
428 messages.push(open_ai::RequestMessage::Tool {
429 content: content.into(),
430 tool_call_id: tool_result.tool_use_id.to_string(),
431 });
432 }
433 }
434 }
435 }
436
437 open_ai::Request {
438 model: model_id.into(),
439 messages,
440 stream,
441 stop: request.stop,
442 temperature: request.temperature.unwrap_or(1.0),
443 max_completion_tokens: max_output_tokens,
444 parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
445 // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
446 Some(false)
447 } else {
448 None
449 },
450 prompt_cache_key: if supports_prompt_cache_key {
451 request.thread_id
452 } else {
453 None
454 },
455 tools: request
456 .tools
457 .into_iter()
458 .map(|tool| open_ai::ToolDefinition::Function {
459 function: open_ai::FunctionDefinition {
460 name: tool.name,
461 description: Some(tool.description),
462 parameters: Some(tool.input_schema),
463 },
464 })
465 .collect(),
466 tool_choice: request.tool_choice.map(|choice| match choice {
467 LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
468 LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
469 LanguageModelToolChoice::None => open_ai::ToolChoice::None,
470 }),
471 reasoning_effort,
472 }
473}
474
475fn add_message_content_part(
476 new_part: open_ai::MessagePart,
477 role: Role,
478 messages: &mut Vec<open_ai::RequestMessage>,
479) {
480 match (role, messages.last_mut()) {
481 (Role::User, Some(open_ai::RequestMessage::User { content }))
482 | (
483 Role::Assistant,
484 Some(open_ai::RequestMessage::Assistant {
485 content: Some(content),
486 ..
487 }),
488 )
489 | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
490 content.push_part(new_part);
491 }
492 _ => {
493 messages.push(match role {
494 Role::User => open_ai::RequestMessage::User {
495 content: open_ai::MessageContent::from(vec![new_part]),
496 },
497 Role::Assistant => open_ai::RequestMessage::Assistant {
498 content: Some(open_ai::MessageContent::from(vec![new_part])),
499 tool_calls: Vec::new(),
500 },
501 Role::System => open_ai::RequestMessage::System {
502 content: open_ai::MessageContent::from(vec![new_part]),
503 },
504 });
505 }
506 }
507}
508
509pub struct OpenAiEventMapper {
510 tool_calls_by_index: HashMap<usize, RawToolCall>,
511}
512
513impl OpenAiEventMapper {
514 pub fn new() -> Self {
515 Self {
516 tool_calls_by_index: HashMap::default(),
517 }
518 }
519
520 pub fn map_stream(
521 mut self,
522 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
523 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
524 {
525 events.flat_map(move |event| {
526 futures::stream::iter(match event {
527 Ok(event) => self.map_event(event),
528 Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
529 })
530 })
531 }
532
533 pub fn map_event(
534 &mut self,
535 event: ResponseStreamEvent,
536 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
537 let mut events = Vec::new();
538 if let Some(usage) = event.usage {
539 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
540 input_tokens: usage.prompt_tokens,
541 output_tokens: usage.completion_tokens,
542 cache_creation_input_tokens: 0,
543 cache_read_input_tokens: 0,
544 })));
545 }
546
547 let Some(choice) = event.choices.first() else {
548 return events;
549 };
550
551 if let Some(content) = choice.delta.content.clone() {
552 if !content.is_empty() {
553 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
554 }
555 }
556
557 if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
558 for tool_call in tool_calls {
559 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
560
561 if let Some(tool_id) = tool_call.id.clone() {
562 entry.id = tool_id;
563 }
564
565 if let Some(function) = tool_call.function.as_ref() {
566 if let Some(name) = function.name.clone() {
567 entry.name = name;
568 }
569
570 if let Some(arguments) = function.arguments.clone() {
571 entry.arguments.push_str(&arguments);
572 }
573 }
574 }
575 }
576
577 match choice.finish_reason.as_deref() {
578 Some("stop") => {
579 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
580 }
581 Some("tool_calls") => {
582 events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
583 match serde_json::Value::from_str(&tool_call.arguments) {
584 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
585 LanguageModelToolUse {
586 id: tool_call.id.clone().into(),
587 name: tool_call.name.as_str().into(),
588 is_input_complete: true,
589 input,
590 raw_input: tool_call.arguments.clone(),
591 },
592 )),
593 Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
594 id: tool_call.id.into(),
595 tool_name: tool_call.name.into(),
596 raw_input: tool_call.arguments.clone().into(),
597 json_parse_error: error.to_string(),
598 }),
599 }
600 }));
601
602 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
603 }
604 Some(stop_reason) => {
605 log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
606 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
607 }
608 None => {}
609 }
610
611 events
612 }
613}
614
/// A tool call being assembled from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id; overwritten whenever a delta carries one.
    id: String,
    /// Function name; overwritten whenever a delta carries one.
    name: String,
    /// JSON argument text, concatenated across deltas.
    arguments: String,
}
621
622pub(crate) fn collect_tiktoken_messages(
623 request: LanguageModelRequest,
624) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
625 request
626 .messages
627 .into_iter()
628 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
629 role: match message.role {
630 Role::User => "user".into(),
631 Role::Assistant => "assistant".into(),
632 Role::System => "system".into(),
633 },
634 content: Some(message.string_contents()),
635 name: None,
636 function_call: None,
637 })
638 .collect::<Vec<_>>()
639}
640
641pub fn count_open_ai_tokens(
642 request: LanguageModelRequest,
643 model: Model,
644 cx: &App,
645) -> BoxFuture<'static, Result<u64>> {
646 cx.background_spawn(async move {
647 let messages = collect_tiktoken_messages(request);
648
649 match model {
650 Model::Custom { max_tokens, .. } => {
651 let model = if max_tokens >= 100_000 {
652 // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
653 "gpt-4o"
654 } else {
655 // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
656 // supported with this tiktoken method
657 "gpt-4"
658 };
659 tiktoken_rs::num_tokens_from_messages(model, &messages)
660 }
661 // Currently supported by tiktoken_rs
662 // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
663 // arm with an override. We enumerate all supported models here so that we can check if new
664 // models are supported yet or not.
665 Model::ThreePointFiveTurbo
666 | Model::Four
667 | Model::FourTurbo
668 | Model::FourOmni
669 | Model::FourOmniMini
670 | Model::FourPointOne
671 | Model::FourPointOneMini
672 | Model::FourPointOneNano
673 | Model::O1
674 | Model::O3
675 | Model::O3Mini
676 | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
677 // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
678 Model::Five | Model::FiveMini | Model::FiveNano => {
679 tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
680 }
681 }
682 .map(|tokens| tokens as u64)
683 })
684 .boxed()
685}
686
/// Configuration view for entering, showing, or resetting the OpenAI API key.
struct ConfigurationView {
    /// Single-line input the user pastes an API key into.
    api_key_editor: Entity<SingleLineInput>,
    /// Provider state; the view re-renders whenever it notifies.
    state: gpui::Entity<State>,
    /// Initial credential load. While `Some`, a "Loading credentials…" placeholder is rendered.
    load_credentials_task: Option<Task<()>>,
}
692
693impl ConfigurationView {
694 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
695 let api_key_editor = cx.new(|cx| {
696 SingleLineInput::new(
697 window,
698 cx,
699 "sk-000000000000000000000000000000000000000000000000",
700 )
701 });
702
703 cx.observe(&state, |_, _, cx| {
704 cx.notify();
705 })
706 .detach();
707
708 let load_credentials_task = Some(cx.spawn_in(window, {
709 let state = state.clone();
710 async move |this, cx| {
711 if let Some(task) = state
712 .update(cx, |state, cx| state.authenticate(cx))
713 .log_err()
714 {
715 // We don't log an error, because "not signed in" is also an error.
716 let _ = task.await;
717 }
718 this.update(cx, |this, cx| {
719 this.load_credentials_task = None;
720 cx.notify();
721 })
722 .log_err();
723 }
724 }));
725
726 Self {
727 api_key_editor,
728 state,
729 load_credentials_task,
730 }
731 }
732
733 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
734 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
735 if api_key.is_empty() {
736 return;
737 }
738
739 // url changes can cause the editor to be displayed again
740 self.api_key_editor
741 .update(cx, |editor, cx| editor.set_text("", window, cx));
742
743 let state = self.state.clone();
744 cx.spawn_in(window, async move |_, cx| {
745 state
746 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
747 .await
748 })
749 .detach_and_log_err(cx);
750 }
751
752 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
753 self.api_key_editor
754 .update(cx, |input, cx| input.set_text("", window, cx));
755
756 let state = self.state.clone();
757 cx.spawn_in(window, async move |_, cx| {
758 state
759 .update(cx, |state, cx| state.set_api_key(None, cx))?
760 .await
761 })
762 .detach_and_log_err(cx);
763 }
764
765 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
766 !self.state.read(cx).is_authenticated()
767 }
768}
769
/// Renders either the API-key entry form (no key yet) or a "key configured" row with a reset
/// button. Both are followed by a note about OpenAI-compatible providers. While the initial
/// credential load is running, only a loading label is shown.
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Whether the current key came from the `OPENAI_API_KEY` environment variable.
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        let api_key_section = if self.should_render_editor(cx) {
            // No key: setup instructions plus the key input. `menu::Confirm` actions (presumably
            // Enter in the input) are routed to `save_api_key`.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Key present: report where it came from and offer a reset button.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        // For an env-var key, the tooltip tells the user to unset the variable.
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        // Pointer to the docs for OpenAI-compatible providers. It gets a top border only when
        // drawn below the key-entry form.
        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        // The placeholder is shown until the task spawned in `ConfigurationView::new` clears
        // `load_credentials_task`.
        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}
881
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    /// Counts tokens for a one-message request with every model yielded by `Model::iter()`.
    /// Asserts each count is positive, i.e. `count_open_ai_tokens` finds a usable tokenizer for
    /// every variant.
    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs
        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
}