use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

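/// An OpenAI chat-completions provider for Zed. Active SSE streams are kept
/// in `streams`, keyed by ids minted from `next_stream_id`.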
struct OpenAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

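/// Per-stream parser state. `buffer` holds partial SSE lines between chunk
/// reads, `tool_calls` accumulates streamed tool-call deltas by index, and
/// `pending_stop` defers the final Stop event until every accumulated tool
/// call has been emitted.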
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    pending_stop: Option<LlmStopReason>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

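/// A streamed tool call assembled from incremental deltas: the id and name
/// arrive once, while the JSON `arguments` string accumulates across chunks.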
struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

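/// The model catalog. Context sizes follow OpenAI's published limits at the
/// time of writing; the display name doubles as the id reported to Zed.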
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

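/// Maps the display name (used as the model id reported to Zed) back to the
/// OpenAI model name that goes on the wire.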
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

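/// Request body for `POST /v1/chat/completions`. Requests always stream
/// (`stream: true`) with usage reporting enabled.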
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Note: OpenAI has deprecated `max_tokens` in favor of
    // `max_completion_tokens`, and the o1/o3 reasoning models reject the old
    // field, so this may need migrating.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

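/// The subset of the streamed `chat.completion.chunk` payload that this
/// extension reads; serde ignores any fields not modeled here.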
#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

/// Error envelope OpenAI returns on failed requests; nothing parses it yet,
/// hence the dead-code allowances.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiError {
    error: OpenAiErrorDetail,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiErrorDetail {
    message: String,
}

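/// Converts Zed's provider-agnostic request into the Chat Completions wire
/// format: system text is joined, user content is split into multimodal
/// parts, tool results become `tool` role messages, and assistant tool uses
/// become `tool_calls`.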
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let mut messages = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        // The image payload is assumed to be base64 PNG data;
                        // other formats would need their own MIME type here.
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        // Tool results are handled separately below.
                        LlmMessageContent::ToolResult(_) => None,
                        _ => None,
                    })
                    .collect();

                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        // Tool results become dedicated `tool` role messages,
                        // pushed before the user content so they directly
                        // follow the assistant message that issued the calls.
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Append rather than overwrite, so an assistant
                            // message with several text parts survives intact.
                            content_text.get_or_insert_with(String::new).push_str(t);
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        // Fall back to an empty object if the tool's input
                        // schema is not valid JSON.
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

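/// Parses a single SSE line of the form `data: {json}`. Returns `None` for
/// the `[DONE]` sentinel, for non-data lines (comments, keep-alives), and for
/// payloads that fail to deserialize.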
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    if let Some(data) = line.strip_prefix("data: ") {
        if data == "[DONE]" {
            return None;
        }
        serde_json::from_str(data).ok()
    } else {
        None
    }
}

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "[Create an API key](https://platform.openai.com/api-keys) to use OpenAI as your LLM provider.".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

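    /// Starts a streaming completion: builds the request body, opens an SSE
    /// HTTP stream, and returns an opaque stream id for later `next`/`close`
    /// calls.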
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

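    /// Pulls the next completion event. Each call returns at most one event:
    /// buffered SSE lines are processed first, and more bytes are read from
    /// the HTTP stream only when the buffer runs dry. `Ok(None)` marks the
    /// end of the stream.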
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Once a finish reason has been seen, drain the remaining accumulated
        // tool calls one per call, in index order, before reading any further.
        if state.pending_stop.is_some() {
            if let Some(&index) = state.tool_calls.keys().min() {
                if let Some(tool_call) = state.tool_calls.remove(&index) {
                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                        id: tool_call.id,
                        name: tool_call.name,
                        input: tool_call.arguments,
                        is_input_complete: true,
                        thought_signature: None,
                    })));
                }
            }
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            // Drain complete lines from the buffer before fetching more bytes.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].trim().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        // Tool calls arrive as partial deltas keyed by index;
                        // accumulate them until a finish reason arrives.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }

                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };
                            state.pending_stop = Some(stop_reason);

                            // Emit the first accumulated tool call now; the
                            // guard at the top of this function drains the
                            // rest, and the deferred Stop goes out at the end
                            // of the stream.
                            if let Some(&index) = state.tool_calls.keys().min() {
                                if let Some(tool_call) = state.tool_calls.remove(&index) {
                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                        id: tool_call.id,
                                        name: tool_call.name,
                                        input: tool_call.arguments,
                                        is_input_complete: true,
                                        thought_signature: None,
                                    })));
                                }
                            }

                            if let Some(usage) = event.usage {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: None,
                                    cache_read_input_tokens: None,
                                })));
                            }

                            if let Some(stop_reason) = state.pending_stop.take() {
                                return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                            }
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    // With `include_usage`, the final chunk carries usage
                    // stats and an empty `choices` array.
                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // End of stream: flush any tool calls that never saw a
                    // finish reason, then the deferred stop reason, if any.
                    if let Some(&index) = state.tool_calls.keys().min() {
                        if let Some(tool_call) = state.tool_calls.remove(&index) {
                            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                id: tool_call.id,
                                name: tool_call.name,
                                input: tool_call.arguments,
                                is_input_complete: true,
                                thought_signature: None,
                            })));
                        }
                    }
                    if let Some(stop_reason) = state.pending_stop.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

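    /// Discards a stream's state; dropping it also drops the underlying HTTP
    /// response stream.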
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);
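
// Minimal sanity checks for the pure helpers above. This is a sketch: Zed
// extensions build for wasm32-wasip1, so running these with a plain
// `cargo test` assumes the crate also compiles for your host target.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_content_delta_and_done_sentinel() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = parse_sse_line(line).expect("content delta should parse");
        assert_eq!(event.choices[0].delta.content.as_deref(), Some("Hi"));

        // The terminator and non-data lines are swallowed rather than parsed.
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
    }

    #[test]
    fn display_names_resolve_to_real_model_ids() {
        assert_eq!(get_real_model_id("GPT-4o"), Some("gpt-4o"));
        assert_eq!(get_real_model_id("does-not-exist"), None);
    }
}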