openai.go

  1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
  2package openai
  3
  4import (
  5	"cmp"
  6	"maps"
  7
  8	"charm.land/fantasy"
  9	"github.com/openai/openai-go/v2"
 10	"github.com/openai/openai-go/v2/option"
 11)
 12
 13const (
 14	// Name is the name of the OpenAI provider.
 15	Name = "openai"
 16	// DefaultURL is the default URL for the OpenAI API.
 17	DefaultURL = "https://api.openai.com/v1"
 18)
 19
 20type provider struct {
 21	options options
 22}
 23
 24type options struct {
 25	baseURL              string
 26	apiKey               string
 27	organization         string
 28	project              string
 29	name                 string
 30	useResponsesAPI      bool
 31	headers              map[string]string
 32	client               option.HTTPClient
 33	sdkOptions           []option.RequestOption
 34	languageModelOptions []LanguageModelOption
 35}
 36
 37// Option defines a function that configures OpenAI provider options.
 38type Option = func(*options)
 39
 40// New creates a new OpenAI provider with the given options.
 41func New(opts ...Option) (fantasy.Provider, error) {
 42	providerOptions := options{
 43		headers:              map[string]string{},
 44		languageModelOptions: make([]LanguageModelOption, 0),
 45	}
 46	for _, o := range opts {
 47		o(&providerOptions)
 48	}
 49
 50	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 51	providerOptions.name = cmp.Or(providerOptions.name, Name)
 52
 53	if providerOptions.organization != "" {
 54		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 55	}
 56	if providerOptions.project != "" {
 57		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 58	}
 59
 60	return &provider{options: providerOptions}, nil
 61}
 62
 63// WithBaseURL sets the base URL for the OpenAI provider.
 64func WithBaseURL(baseURL string) Option {
 65	return func(o *options) {
 66		o.baseURL = baseURL
 67	}
 68}
 69
 70// WithAPIKey sets the API key for the OpenAI provider.
 71func WithAPIKey(apiKey string) Option {
 72	return func(o *options) {
 73		o.apiKey = apiKey
 74	}
 75}
 76
 77// WithOrganization sets the organization for the OpenAI provider.
 78func WithOrganization(organization string) Option {
 79	return func(o *options) {
 80		o.organization = organization
 81	}
 82}
 83
 84// WithProject sets the project for the OpenAI provider.
 85func WithProject(project string) Option {
 86	return func(o *options) {
 87		o.project = project
 88	}
 89}
 90
 91// WithName sets the name for the OpenAI provider.
 92func WithName(name string) Option {
 93	return func(o *options) {
 94		o.name = name
 95	}
 96}
 97
 98// WithHeaders sets the headers for the OpenAI provider.
 99func WithHeaders(headers map[string]string) Option {
100	return func(o *options) {
101		maps.Copy(o.headers, headers)
102	}
103}
104
105// WithHTTPClient sets the HTTP client for the OpenAI provider.
106func WithHTTPClient(client option.HTTPClient) Option {
107	return func(o *options) {
108		o.client = client
109	}
110}
111
112// WithSDKOptions sets the SDK options for the OpenAI provider.
113func WithSDKOptions(opts ...option.RequestOption) Option {
114	return func(o *options) {
115		o.sdkOptions = append(o.sdkOptions, opts...)
116	}
117}
118
119// WithLanguageModelOptions sets the language model options for the OpenAI provider.
120func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
121	return func(o *options) {
122		o.languageModelOptions = append(o.languageModelOptions, opts...)
123	}
124}
125
126// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
127func WithUseResponsesAPI() Option {
128	return func(o *options) {
129		o.useResponsesAPI = true
130	}
131}
132
133// LanguageModel implements fantasy.Provider.
134func (o *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
135	openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
136
137	if o.options.apiKey != "" {
138		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
139	}
140	if o.options.baseURL != "" {
141		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
142	}
143
144	for key, value := range o.options.headers {
145		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
146	}
147
148	if o.options.client != nil {
149		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
150	}
151
152	openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
153
154	client := openai.NewClient(openaiClientOptions...)
155
156	if o.options.useResponsesAPI && IsResponsesModel(modelID) {
157		return newResponsesLanguageModel(modelID, o.options.name, client), nil
158	}
159
160	return newLanguageModel(
161		modelID,
162		o.options.name,
163		client,
164		o.options.languageModelOptions...,
165	), nil
166}
167
168func (o *provider) Name() string {
169	return Name
170}