1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
2package openai
3
4import (
5 "cmp"
6 "context"
7 "maps"
8
9 "charm.land/fantasy"
10 "github.com/openai/openai-go/v2"
11 "github.com/openai/openai-go/v2/option"
12)
13
const (
	// Name is the default name of the OpenAI provider. New falls back to it
	// when no name has been configured through WithName.
	Name = "openai"
	// DefaultURL is the default base URL for the OpenAI API. New falls back to
	// it when no base URL has been configured through WithBaseURL.
	DefaultURL = "https://api.openai.com/v1"
)
20
// provider is the value returned by New as a fantasy.Provider. It carries the
// resolved configuration from which LanguageModel builds an OpenAI client.
type provider struct {
	options options
}
24
// options collects every setting that an Option can configure on the provider.
type options struct {
	// baseURL is the API base URL; New defaults it to DefaultURL when empty.
	baseURL string
	// apiKey is passed to the SDK client when non-empty.
	apiKey string
	// organization, when non-empty, is sent as the "OpenAi-Organization" header.
	organization string
	// project, when non-empty, is sent as the "OpenAi-Project" header.
	project string
	// name is the provider name handed to created language models; New
	// defaults it to Name when empty.
	name string
	// useResponsesAPI opts in to the responses API for models for which
	// IsResponsesModel reports true.
	useResponsesAPI bool
	// headers holds extra headers; each entry becomes an option.WithHeader on
	// the SDK client.
	headers map[string]string
	// client is an optional custom HTTP client, used when non-nil.
	client option.HTTPClient
	// sdkOptions are additional SDK request options, appended after the
	// options derived from the fields above.
	sdkOptions []option.RequestOption
	// objectMode is the object generation mode forwarded to language models.
	objectMode fantasy.ObjectMode
	// languageModelOptions are forwarded to newLanguageModel.
	languageModelOptions []LanguageModelOption
}
38
// Option defines a function that configures OpenAI provider options.
// New applies the options it receives in the order they are passed.
type Option = func(*options)
41
42// New creates a new OpenAI provider with the given options.
43func New(opts ...Option) (fantasy.Provider, error) {
44 providerOptions := options{
45 headers: map[string]string{},
46 languageModelOptions: make([]LanguageModelOption, 0),
47 }
48 for _, o := range opts {
49 o(&providerOptions)
50 }
51
52 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
53 providerOptions.name = cmp.Or(providerOptions.name, Name)
54
55 if providerOptions.organization != "" {
56 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
57 }
58 if providerOptions.project != "" {
59 providerOptions.headers["OpenAi-Project"] = providerOptions.project
60 }
61
62 return &provider{options: providerOptions}, nil
63}
64
65// WithBaseURL sets the base URL for the OpenAI provider.
66func WithBaseURL(baseURL string) Option {
67 return func(o *options) {
68 o.baseURL = baseURL
69 }
70}
71
72// WithAPIKey sets the API key for the OpenAI provider.
73func WithAPIKey(apiKey string) Option {
74 return func(o *options) {
75 o.apiKey = apiKey
76 }
77}
78
79// WithOrganization sets the organization for the OpenAI provider.
80func WithOrganization(organization string) Option {
81 return func(o *options) {
82 o.organization = organization
83 }
84}
85
86// WithProject sets the project for the OpenAI provider.
87func WithProject(project string) Option {
88 return func(o *options) {
89 o.project = project
90 }
91}
92
93// WithName sets the name for the OpenAI provider.
94func WithName(name string) Option {
95 return func(o *options) {
96 o.name = name
97 }
98}
99
100// WithHeaders sets the headers for the OpenAI provider.
101func WithHeaders(headers map[string]string) Option {
102 return func(o *options) {
103 maps.Copy(o.headers, headers)
104 }
105}
106
107// WithHTTPClient sets the HTTP client for the OpenAI provider.
108func WithHTTPClient(client option.HTTPClient) Option {
109 return func(o *options) {
110 o.client = client
111 }
112}
113
114// WithSDKOptions sets the SDK options for the OpenAI provider.
115func WithSDKOptions(opts ...option.RequestOption) Option {
116 return func(o *options) {
117 o.sdkOptions = append(o.sdkOptions, opts...)
118 }
119}
120
121// WithLanguageModelOptions sets the language model options for the OpenAI provider.
122func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
123 return func(o *options) {
124 o.languageModelOptions = append(o.languageModelOptions, opts...)
125 }
126}
127
128// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
129func WithUseResponsesAPI() Option {
130 return func(o *options) {
131 o.useResponsesAPI = true
132 }
133}
134
135// WithObjectMode sets the object generation mode.
136func WithObjectMode(om fantasy.ObjectMode) Option {
137 return func(o *options) {
138 // not supported
139 if om == fantasy.ObjectModeJSON {
140 om = fantasy.ObjectModeAuto
141 }
142 o.objectMode = om
143 }
144}
145
146// LanguageModel implements fantasy.Provider.
147func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
148 openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
149 openaiClientOptions = append(openaiClientOptions, option.WithMaxRetries(0))
150
151 if o.options.apiKey != "" {
152 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
153 }
154 if o.options.baseURL != "" {
155 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
156 }
157
158 for key, value := range o.options.headers {
159 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
160 }
161
162 if o.options.client != nil {
163 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
164 }
165
166 openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
167
168 client := openai.NewClient(openaiClientOptions...)
169
170 if o.options.useResponsesAPI && IsResponsesModel(modelID) {
171 // Not supported for responses API
172 objectMode := o.options.objectMode
173 if objectMode == fantasy.ObjectModeJSON {
174 objectMode = fantasy.ObjectModeAuto
175 }
176 return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil
177 }
178
179 o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode))
180
181 return newLanguageModel(
182 modelID,
183 o.options.name,
184 client,
185 o.options.languageModelOptions...,
186 ), nil
187}
188
189func (o *provider) Name() string {
190 return Name
191}