1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
  2package openai
  3
import (
	"cmp"
	"context"
	"maps"
	"slices"

	"charm.land/fantasy"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
 13
 14const (
 15	// Name is the name of the OpenAI provider.
 16	Name = "openai"
 17	// DefaultURL is the default URL for the OpenAI API.
 18	DefaultURL = "https://api.openai.com/v1"
 19)
 20
// provider is the OpenAI implementation of fantasy.Provider returned by New.
// It holds the resolved configuration, and every LanguageModel call builds a
// fresh openai.Client from it.
type provider struct {
	options options
}
 24
 25type options struct {
 26	baseURL              string
 27	apiKey               string
 28	organization         string
 29	project              string
 30	name                 string
 31	useResponsesAPI      bool
 32	headers              map[string]string
 33	client               option.HTTPClient
 34	sdkOptions           []option.RequestOption
 35	languageModelOptions []LanguageModelOption
 36}
 37
// Option configures the OpenAI provider. New applies options in the order
// they are passed. Option is a type alias, so any func(*options) value
// qualifies.
type Option = func(*options)
 40
 41// New creates a new OpenAI provider with the given options.
 42func New(opts ...Option) (fantasy.Provider, error) {
 43	providerOptions := options{
 44		headers:              map[string]string{},
 45		languageModelOptions: make([]LanguageModelOption, 0),
 46	}
 47	for _, o := range opts {
 48		o(&providerOptions)
 49	}
 50
 51	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 52	providerOptions.name = cmp.Or(providerOptions.name, Name)
 53
 54	if providerOptions.organization != "" {
 55		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 56	}
 57	if providerOptions.project != "" {
 58		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 59	}
 60
 61	return &provider{options: providerOptions}, nil
 62}
 63
 64// WithBaseURL sets the base URL for the OpenAI provider.
 65func WithBaseURL(baseURL string) Option {
 66	return func(o *options) {
 67		o.baseURL = baseURL
 68	}
 69}
 70
 71// WithAPIKey sets the API key for the OpenAI provider.
 72func WithAPIKey(apiKey string) Option {
 73	return func(o *options) {
 74		o.apiKey = apiKey
 75	}
 76}
 77
 78// WithOrganization sets the organization for the OpenAI provider.
 79func WithOrganization(organization string) Option {
 80	return func(o *options) {
 81		o.organization = organization
 82	}
 83}
 84
 85// WithProject sets the project for the OpenAI provider.
 86func WithProject(project string) Option {
 87	return func(o *options) {
 88		o.project = project
 89	}
 90}
 91
 92// WithName sets the name for the OpenAI provider.
 93func WithName(name string) Option {
 94	return func(o *options) {
 95		o.name = name
 96	}
 97}
 98
 99// WithHeaders sets the headers for the OpenAI provider.
100func WithHeaders(headers map[string]string) Option {
101	return func(o *options) {
102		maps.Copy(o.headers, headers)
103	}
104}
105
106// WithHTTPClient sets the HTTP client for the OpenAI provider.
107func WithHTTPClient(client option.HTTPClient) Option {
108	return func(o *options) {
109		o.client = client
110	}
111}
112
113// WithSDKOptions sets the SDK options for the OpenAI provider.
114func WithSDKOptions(opts ...option.RequestOption) Option {
115	return func(o *options) {
116		o.sdkOptions = append(o.sdkOptions, opts...)
117	}
118}
119
120// WithLanguageModelOptions sets the language model options for the OpenAI provider.
121func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
122	return func(o *options) {
123		o.languageModelOptions = append(o.languageModelOptions, opts...)
124	}
125}
126
127// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
128func WithUseResponsesAPI() Option {
129	return func(o *options) {
130		o.useResponsesAPI = true
131	}
132}
133
134// LanguageModel implements fantasy.Provider.
135func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
136	openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
137
138	if o.options.apiKey != "" {
139		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
140	}
141	if o.options.baseURL != "" {
142		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
143	}
144
145	for key, value := range o.options.headers {
146		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
147	}
148
149	if o.options.client != nil {
150		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
151	}
152
153	openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
154
155	client := openai.NewClient(openaiClientOptions...)
156
157	if o.options.useResponsesAPI && IsResponsesModel(modelID) {
158		return newResponsesLanguageModel(modelID, o.options.name, client), nil
159	}
160
161	return newLanguageModel(
162		modelID,
163		o.options.name,
164		client,
165		o.options.languageModelOptions...,
166	), nil
167}
168
// Name implements fantasy.Provider.
//
// It always returns the package constant Name ("openai"), even when a custom
// name was configured with WithName. That custom name is only handed to the
// language models created by LanguageModel.
// NOTE(review): confirm this is intentional (e.g. callers keying provider
// options by the constant) rather than an oversight.
func (o *provider) Name() string {
	return Name
}