1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
2package openai
3
import (
	"cmp"
	"context"
	"maps"
	"sort"

	"charm.land/fantasy"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
13
const (
	// DefaultURL is the OpenAI API endpoint that New falls back to when no
	// base URL has been configured.
	DefaultURL = "https://api.openai.com/v1"
	// Name identifies the OpenAI provider. It is the default provider name
	// and the value reported by the provider's Name method.
	Name = "openai"
)
20
21type provider struct {
22 options options
23}
24
25type options struct {
26 baseURL string
27 apiKey string
28 organization string
29 project string
30 name string
31 useResponsesAPI bool
32 headers map[string]string
33 client option.HTTPClient
34 sdkOptions []option.RequestOption
35 languageModelOptions []LanguageModelOption
36}
37
38// Option defines a function that configures OpenAI provider options.
39type Option = func(*options)
40
41// New creates a new OpenAI provider with the given options.
42func New(opts ...Option) (fantasy.Provider, error) {
43 providerOptions := options{
44 headers: map[string]string{},
45 languageModelOptions: make([]LanguageModelOption, 0),
46 }
47 for _, o := range opts {
48 o(&providerOptions)
49 }
50
51 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
52 providerOptions.name = cmp.Or(providerOptions.name, Name)
53
54 if providerOptions.organization != "" {
55 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
56 }
57 if providerOptions.project != "" {
58 providerOptions.headers["OpenAi-Project"] = providerOptions.project
59 }
60
61 return &provider{options: providerOptions}, nil
62}
63
64// WithBaseURL sets the base URL for the OpenAI provider.
65func WithBaseURL(baseURL string) Option {
66 return func(o *options) {
67 o.baseURL = baseURL
68 }
69}
70
71// WithAPIKey sets the API key for the OpenAI provider.
72func WithAPIKey(apiKey string) Option {
73 return func(o *options) {
74 o.apiKey = apiKey
75 }
76}
77
78// WithOrganization sets the organization for the OpenAI provider.
79func WithOrganization(organization string) Option {
80 return func(o *options) {
81 o.organization = organization
82 }
83}
84
85// WithProject sets the project for the OpenAI provider.
86func WithProject(project string) Option {
87 return func(o *options) {
88 o.project = project
89 }
90}
91
92// WithName sets the name for the OpenAI provider.
93func WithName(name string) Option {
94 return func(o *options) {
95 o.name = name
96 }
97}
98
99// WithHeaders sets the headers for the OpenAI provider.
100func WithHeaders(headers map[string]string) Option {
101 return func(o *options) {
102 maps.Copy(o.headers, headers)
103 }
104}
105
106// WithHTTPClient sets the HTTP client for the OpenAI provider.
107func WithHTTPClient(client option.HTTPClient) Option {
108 return func(o *options) {
109 o.client = client
110 }
111}
112
113// WithSDKOptions sets the SDK options for the OpenAI provider.
114func WithSDKOptions(opts ...option.RequestOption) Option {
115 return func(o *options) {
116 o.sdkOptions = append(o.sdkOptions, opts...)
117 }
118}
119
120// WithLanguageModelOptions sets the language model options for the OpenAI provider.
121func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
122 return func(o *options) {
123 o.languageModelOptions = append(o.languageModelOptions, opts...)
124 }
125}
126
127// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
128func WithUseResponsesAPI() Option {
129 return func(o *options) {
130 o.useResponsesAPI = true
131 }
132}
133
134// LanguageModel implements fantasy.Provider.
135func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
136 openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
137
138 if o.options.apiKey != "" {
139 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
140 }
141 if o.options.baseURL != "" {
142 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
143 }
144
145 for key, value := range o.options.headers {
146 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
147 }
148
149 if o.options.client != nil {
150 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
151 }
152
153 openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
154
155 client := openai.NewClient(openaiClientOptions...)
156
157 if o.options.useResponsesAPI && IsResponsesModel(modelID) {
158 return newResponsesLanguageModel(modelID, o.options.name, client), nil
159 }
160
161 return newLanguageModel(
162 modelID,
163 o.options.name,
164 client,
165 o.options.languageModelOptions...,
166 ), nil
167}
168
169func (o *provider) Name() string {
170 return Name
171}