use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

struct OpenAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    // Tool calls arrive as fragments keyed by choice index; they are
    // accumulated here and flushed once a finish event is seen.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // OpenAI has deprecated `max_tokens` on Chat Completions, and the
    // o-series reasoning models reject it outright, so send the
    // `max_completion_tokens` replacement instead.
    #[serde(
        rename = "max_completion_tokens",
        skip_serializing_if = "Option::is_none"
    )]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiError {
    error: OpenAiErrorDetail,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiErrorDetail {
    message: String,
}

fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let mut messages = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                // The host appears to hand over raw base64 data
                                // without a media type, so PNG is assumed here.
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        LlmMessageContent::ToolResult(_) => None,
                        _ => None,
                    })
                    .collect();

                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Append rather than overwrite, in case the message
                            // carries multiple text parts.
                            content_text.get_or_insert_with(String::new).push_str(t);
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}
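
// A minimal serialization check for the request body. It exercises only
// types defined in this file, so it assumes nothing about the
// zed_extension_api surface; the expected JSON mirrors the public Chat
// Completions wire format.
#[cfg(test)]
mod request_tests {
    use super::*;

    #[test]
    fn empty_optionals_are_omitted_from_the_wire_format() {
        let request = OpenAiRequest {
            model: "gpt-4o".to_string(),
            messages: vec![OpenAiMessage::System {
                content: "You are helpful.".to_string(),
            }],
            tools: None,
            tool_choice: None,
            temperature: None,
            max_tokens: None,
            stop: Vec::new(),
            stream: true,
            stream_options: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "gpt-4o");
        assert_eq!(json["messages"][0]["role"], "system");
        // `skip_serializing_if` drops these keys entirely.
        assert!(json.get("tools").is_none());
        assert!(json.get("stop").is_none());
    }
}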

fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    // SSE data lines are "data:" optionally followed by one space.
    let data = line.strip_prefix("data:")?.trim_start();
    if data == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}
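
// A small sketch of the frames this parser expects, based on the publicly
// documented Chat Completions streaming format; real payloads carry extra
// fields, which serde ignores here.
#[cfg(test)]
mod sse_tests {
    use super::*;

    #[test]
    fn parses_a_content_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = parse_sse_line(line).expect("data line should parse");
        assert_eq!(event.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(event.usage.is_none());
    }

    #[test]
    fn done_marker_and_non_data_lines_yield_nothing() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("").is_none());
    }
}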

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# OpenAI Setup

Welcome to **OpenAI**! This extension provides access to OpenAI GPT models.

## Configuration

Enter your OpenAI API key below. You can find your API key at [platform.openai.com/api-keys](https://platform.openai.com/api-keys).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| GPT-4o | gpt-4o | 128K | 16K |
| GPT-4o-mini | gpt-4o-mini | 128K | 16K |
| GPT-4.1 | gpt-4.1 | 1M | 32K |
| GPT-4.1-mini | gpt-4.1-mini | 1M | 32K |
| GPT-4.1-nano | gpt-4.1-nano | 1M | 32K |
| GPT-5 | gpt-5 | 272K | 32K |
| GPT-5-mini | gpt-5-mini | 272K | 32K |
| o1 | o1 | 200K | 100K |
| o3 | o3 | 200K | 100K |
| o3-mini | o3-mini | 200K | 100K |
| o4-mini | o4-mini | 200K | 100K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ Reasoning models (o1, o3, o3-mini, o4-mini)
## Pricing

Uses your OpenAI API credits. See [OpenAI pricing](https://openai.com/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        let provided = llm_request_credential(
            "openai",
            LlmCredentialType::ApiKey,
            "OpenAI API Key",
            "sk-...",
        )?;
        if provided {
            Ok(())
        } else {
            Err("Authentication cancelled".to_string())
        }
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
            },
        );

        Ok(stream_id)
    }
586
587 fn llm_stream_completion_next(
588 &mut self,
589 stream_id: &str,
590 ) -> Result<Option<LlmCompletionEvent>, String> {
591 let mut streams = self.streams.lock().unwrap();
592 let state = streams
593 .get_mut(stream_id)
594 .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
595
596 if !state.started {
597 state.started = true;
598 return Ok(Some(LlmCompletionEvent::Started));
599 }
        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].trim().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }

                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            if !state.tool_calls.is_empty() {
                                // Flush accumulated tool calls one per call, in
                                // index order. Re-queue this finish line so the
                                // remaining calls and the final Stop event are
                                // emitted by subsequent calls.
                                let index = *state
                                    .tool_calls
                                    .keys()
                                    .min()
                                    .expect("map is non-empty");
                                let tool_call = state
                                    .tool_calls
                                    .remove(&index)
                                    .expect("key was taken from the map");
                                state.buffer = format!("{}\n{}", line, state.buffer);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_call.id,
                                    name: tool_call.name,
                                    input: tool_call.arguments,
                                    thought_signature: None,
                                })));
                            }

                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            // With `include_usage` set, token usage arrives in a
                            // separate final chunk with an empty `choices` array
                            // (handled below), so Stop is emitted unconditionally.
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }

                continue;
            }

            let Some(response_stream) = state.response_stream.as_mut() else {
                // The HTTP stream is finished; flush any tool calls that a
                // finish event never released, one per call, in index order.
                if let Some(&index) = state.tool_calls.keys().min() {
                    if let Some(tool_call) = state.tool_calls.remove(&index) {
                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                            id: tool_call.id,
                            name: tool_call.name,
                            input: tool_call.arguments,
                            thought_signature: None,
                        })));
                    }
                }
                return Ok(None);
            };

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Mark the stream finished; the drain above runs on the
                    // next loop iteration.
                    state.response_stream = None;
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);
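
// Sanity checks over the static model table; these assert invariants the
// rest of this file relies on (display names double as model ids, and
// exactly one default plus one fast default exist), not OpenAI behavior.
#[cfg(test)]
mod model_tests {
    use super::*;

    #[test]
    fn display_names_resolve_to_real_ids() {
        assert_eq!(get_real_model_id("GPT-4o"), Some("gpt-4o"));
        assert_eq!(get_real_model_id("o3-mini"), Some("o3-mini"));
        assert_eq!(get_real_model_id("not-a-model"), None);
    }

    #[test]
    fn display_names_are_unique() {
        let mut names: Vec<_> = MODELS.iter().map(|m| m.display_name).collect();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), MODELS.len());
    }

    #[test]
    fn exactly_one_default_and_one_fast_default() {
        assert_eq!(MODELS.iter().filter(|m| m.is_default).count(), 1);
        assert_eq!(MODELS.iter().filter(|m| m.is_default_fast).count(), 1);
    }
}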