use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};

use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;

const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

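/// Settings for the OpenAI provider, read from the user's settings file. An empty
/// `api_url` means the default OpenAI endpoint is used.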
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

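/// A custom model declared in settings. These entries are merged over the built-in
/// model list, so an entry with the same name overrides the built-in definition.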
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    pub reasoning_effort: Option<ReasoningEffort>,
}

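/// The OpenAI language model provider. It owns the shared authentication [`State`]
/// and hands out an [`OpenAiLanguageModel`] per model.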
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

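/// Shared provider state: the API key, which may come from the user's stored
/// credentials or the `OPENAI_API_KEY` environment variable.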
pub struct State {
    api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }
}

impl OpenAiLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
            }
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(OpenAiLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &OpenAiSettings {
        &AllLanguageModelSettings::get_global(cx).openai
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            open_ai::OPEN_AI_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAi
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from open_ai::Model::iter()
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .openai
            .available_models
        {
            models.insert(
                model.name.clone(),
                open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    reasoning_effort: model.reasoning_effort.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

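/// A single OpenAI model exposed to the rest of the app. Completion requests are
/// funneled through a [`RateLimiter`] that allows at most four concurrent requests.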
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        let api_key_and_url = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            let api_key = state.api_key_state.key(&api_url);
            (api_key, api_url)
        });
        let (api_key, api_url) = match api_key_and_url {
            Ok(api_key_and_url) => api_key_and_url,
            Err(err) => {
                return futures::future::ready(Err(err)).boxed();
            }
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        use open_ai::Model;
        match &self.model {
            Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano
            | Model::O1
            | Model::O3
            | Model::O4Mini => true,
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::O3Mini
            | Model::Custom { .. } => false,
        }
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_open_ai_tokens(request, self.model.clone(), cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.model.supports_prompt_cache_key(),
            self.max_output_tokens(),
            self.model.reasoning_effort(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

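/// Converts a [`LanguageModelRequest`] into an OpenAI chat completions request,
/// flattening each message's content parts into OpenAI's message format and
/// translating tool definitions, tool choice, and reasoning effort.
///
/// A minimal sketch of a call site (the argument values here are illustrative, not
/// canonical):
///
/// ```ignore
/// let openai_request = into_open_ai(
///     request,   // a LanguageModelRequest
///     "gpt-4o",  // model id
///     true,      // the model supports parallel tool calls
///     true,      // the model supports prompt_cache_key
///     None,      // no explicit max output tokens
///     None,      // no reasoning effort override
/// );
/// assert!(openai_request.stream);
/// ```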
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
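    // `o1-`-prefixed models (o1-preview, o1-mini) don't support streamed responses,
    // so request a non-streaming completion for them.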
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    add_message_content_part(
                        open_ai::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    )
                }
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
            None
        },
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}

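/// Appends `new_part` to the trailing message if that message has the matching role;
/// otherwise starts a new message. This keeps consecutive parts from the same speaker
/// in a single OpenAI message.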
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(open_ai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

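/// Maps OpenAI stream events to [`LanguageModelCompletionEvent`]s. Tool calls arrive
/// as incremental deltas keyed by index, so their fragments are accumulated in
/// `tool_calls_by_index` and only emitted once a `tool_calls` finish reason is seen.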
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(content) = choice.delta.content.clone() {
            if !content.is_empty() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

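/// A tool call accumulated across stream deltas. `arguments` holds the concatenated
/// JSON fragments, which are only parsed once the call is complete.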
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

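/// Flattens the request's messages into tiktoken-rs chat messages for token counting.
/// Only each message's string contents are counted; tool metadata is left empty.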
pub(crate) fn collect_tiktoken_messages(
    request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
    request
        .messages
        .into_iter()
        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
            role: match message.role {
                Role::User => "user".into(),
                Role::Assistant => "assistant".into(),
                Role::System => "system".into(),
            },
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        })
        .collect::<Vec<_>>()
}

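/// Counts the tokens in `request` on a background thread, mapping each model onto a
/// tokenizer that tiktoken-rs actually ships.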
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);

        match model {
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    // A context window of 100k tokens or more suggests the o200k_base
                    // tokenizer used by gpt-4o.
                    "gpt-4o"
                } else {
                    // Otherwise fall back to gpt-4, since this tiktoken method only
                    // supports the cl100k_base and o200k_base tokenizers.
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Models currently supported by tiktoken-rs.
            // tiktoken-rs is sometimes behind on model support; if so, add a new match
            // arm with an override. We enumerate all supported models here so that we
            // notice when new models gain support.
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // GPT-5 models don't have tiktoken support yet; fall back to the gpt-4o tokenizer.
            Model::Five | Model::FiveMini | Model::FiveNano => {
                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

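/// Configuration UI for the OpenAI provider: an API key editor while no key is set,
/// and a reset button once one is configured.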
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Don't log failures here, because "not signed in" also surfaces
                    // as an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}

#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs
        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
}