openai.go

  1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
  2package openai
  3
import (
	"cmp"
	"context"
	"maps"
	"strings"

	"charm.land/fantasy"
	"charm.land/fantasy/providers/internal/httpheaders"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
 14
 15const (
 16	// Name is the name of the OpenAI provider.
 17	Name = "openai"
 18	// DefaultURL is the default URL for the OpenAI API.
 19	DefaultURL = "https://api.openai.com/v1"
 20)
 21
 22type provider struct {
 23	options options
 24}
 25
 26type options struct {
 27	baseURL              string
 28	apiKey               string
 29	organization         string
 30	project              string
 31	name                 string
 32	useResponsesAPI      bool
 33	headers              map[string]string
 34	userAgent            string
 35	client               option.HTTPClient
 36	sdkOptions           []option.RequestOption
 37	objectMode           fantasy.ObjectMode
 38	languageModelOptions []LanguageModelOption
 39}
 40
 41// Option defines a function that configures OpenAI provider options.
 42type Option = func(*options)
 43
 44// New creates a new OpenAI provider with the given options.
 45func New(opts ...Option) (fantasy.Provider, error) {
 46	providerOptions := options{
 47		headers:              map[string]string{},
 48		languageModelOptions: make([]LanguageModelOption, 0),
 49	}
 50	for _, o := range opts {
 51		o(&providerOptions)
 52	}
 53
 54	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 55	providerOptions.name = cmp.Or(providerOptions.name, Name)
 56
 57	if providerOptions.organization != "" {
 58		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 59	}
 60	if providerOptions.project != "" {
 61		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 62	}
 63
 64	return &provider{options: providerOptions}, nil
 65}
 66
 67// WithBaseURL sets the base URL for the OpenAI provider.
 68func WithBaseURL(baseURL string) Option {
 69	return func(o *options) {
 70		o.baseURL = baseURL
 71	}
 72}
 73
 74// WithAPIKey sets the API key for the OpenAI provider.
 75func WithAPIKey(apiKey string) Option {
 76	return func(o *options) {
 77		o.apiKey = apiKey
 78	}
 79}
 80
 81// WithOrganization sets the organization for the OpenAI provider.
 82func WithOrganization(organization string) Option {
 83	return func(o *options) {
 84		o.organization = organization
 85	}
 86}
 87
 88// WithProject sets the project for the OpenAI provider.
 89func WithProject(project string) Option {
 90	return func(o *options) {
 91		o.project = project
 92	}
 93}
 94
 95// WithName sets the name for the OpenAI provider.
 96func WithName(name string) Option {
 97	return func(o *options) {
 98		o.name = name
 99	}
100}
101
102// WithHeaders sets the headers for the OpenAI provider.
103func WithHeaders(headers map[string]string) Option {
104	return func(o *options) {
105		maps.Copy(o.headers, headers)
106	}
107}
108
109// WithHTTPClient sets the HTTP client for the OpenAI provider.
110func WithHTTPClient(client option.HTTPClient) Option {
111	return func(o *options) {
112		o.client = client
113	}
114}
115
116// WithSDKOptions sets the SDK options for the OpenAI provider.
117func WithSDKOptions(opts ...option.RequestOption) Option {
118	return func(o *options) {
119		o.sdkOptions = append(o.sdkOptions, opts...)
120	}
121}
122
123// WithLanguageModelOptions sets the language model options for the OpenAI provider.
124func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
125	return func(o *options) {
126		o.languageModelOptions = append(o.languageModelOptions, opts...)
127	}
128}
129
130// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
131func WithUseResponsesAPI() Option {
132	return func(o *options) {
133		o.useResponsesAPI = true
134	}
135}
136
137// WithUserAgent sets an explicit User-Agent header, overriding the default and any
138// value set via WithHeaders.
139func WithUserAgent(ua string) Option {
140	return func(o *options) {
141		o.userAgent = ua
142	}
143}
144
145// WithObjectMode sets the object generation mode.
146func WithObjectMode(om fantasy.ObjectMode) Option {
147	return func(o *options) {
148		// not supported
149		if om == fantasy.ObjectModeJSON {
150			om = fantasy.ObjectModeAuto
151		}
152		o.objectMode = om
153	}
154}
155
156// LanguageModel implements fantasy.Provider.
157func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
158	openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
159	openaiClientOptions = append(openaiClientOptions, option.WithMaxRetries(0))
160
161	if o.options.apiKey != "" {
162		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
163	}
164	if o.options.baseURL != "" {
165		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
166	}
167
168	defaultUA := httpheaders.DefaultUserAgent(fantasy.Version)
169	resolved := httpheaders.ResolveHeaders(o.options.headers, o.options.userAgent, defaultUA)
170	for key, value := range resolved {
171		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
172	}
173
174	if o.options.client != nil {
175		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
176	}
177
178	openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
179
180	client := openai.NewClient(openaiClientOptions...)
181
182	if o.options.useResponsesAPI && IsResponsesModel(modelID) {
183		// Not supported for responses API
184		objectMode := o.options.objectMode
185		if objectMode == fantasy.ObjectModeJSON {
186			objectMode = fantasy.ObjectModeAuto
187		}
188		return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil
189	}
190
191	o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode))
192
193	return newLanguageModel(
194		modelID,
195		o.options.name,
196		client,
197		o.options.languageModelOptions...,
198	), nil
199}
200
201func (o *provider) Name() string {
202	return Name
203}