openai.go

// Package openai implements an ai.Provider backed by the OpenAI API.
  2package openai
  3
import (
	"cmp"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
 12
 13const (
 14	Name       = "openai"
 15	DefaultURL = "https://api.openai.com/v1"
 16)
 17
 18type provider struct {
 19	options options
 20}
 21
 22type options struct {
 23	baseURL              string
 24	apiKey               string
 25	organization         string
 26	project              string
 27	name                 string
 28	useResponsesAPI      bool
 29	headers              map[string]string
 30	client               option.HTTPClient
 31	sdkOptions           []option.RequestOption
 32	languageModelOptions []LanguageModelOption
 33}
 34
 35type Option = func(*options)
 36
 37func New(opts ...Option) ai.Provider {
 38	providerOptions := options{
 39		headers:              map[string]string{},
 40		languageModelOptions: make([]LanguageModelOption, 0),
 41	}
 42	for _, o := range opts {
 43		o(&providerOptions)
 44	}
 45
 46	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 47	providerOptions.name = cmp.Or(providerOptions.name, Name)
 48
 49	if providerOptions.organization != "" {
 50		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 51	}
 52	if providerOptions.project != "" {
 53		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 54	}
 55
 56	return &provider{options: providerOptions}
 57}
 58
 59func WithBaseURL(baseURL string) Option {
 60	return func(o *options) {
 61		o.baseURL = baseURL
 62	}
 63}
 64
 65func WithAPIKey(apiKey string) Option {
 66	return func(o *options) {
 67		o.apiKey = apiKey
 68	}
 69}
 70
 71func WithOrganization(organization string) Option {
 72	return func(o *options) {
 73		o.organization = organization
 74	}
 75}
 76
 77func WithProject(project string) Option {
 78	return func(o *options) {
 79		o.project = project
 80	}
 81}
 82
 83func WithName(name string) Option {
 84	return func(o *options) {
 85		o.name = name
 86	}
 87}
 88
 89func WithHeaders(headers map[string]string) Option {
 90	return func(o *options) {
 91		maps.Copy(o.headers, headers)
 92	}
 93}
 94
 95func WithHTTPClient(client option.HTTPClient) Option {
 96	return func(o *options) {
 97		o.client = client
 98	}
 99}
100
101func WithSDKOptions(opts ...option.RequestOption) Option {
102	return func(o *options) {
103		o.sdkOptions = append(o.sdkOptions, opts...)
104	}
105}
106
107func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
108	return func(o *options) {
109		o.languageModelOptions = append(o.languageModelOptions, opts...)
110	}
111}
112
113// WithUseResponsesAPI makes it so the provider uses responses API for models that support it.
114func WithUseResponsesAPI() Option {
115	return func(o *options) {
116		o.useResponsesAPI = true
117	}
118}
119
120// LanguageModel implements ai.Provider.
121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
122	openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
123
124	if o.options.apiKey != "" {
125		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
126	}
127	if o.options.baseURL != "" {
128		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
129	}
130
131	for key, value := range o.options.headers {
132		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
133	}
134
135	if o.options.client != nil {
136		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
137	}
138
139	openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
140
141	client := openai.NewClient(openaiClientOptions...)
142
143	if o.options.useResponsesAPI && IsResponsesModel(modelID) {
144		return newResponsesLanguageModel(modelID, o.options.name, client), nil
145	}
146
147	return newLanguageModel(
148		modelID,
149		o.options.name,
150		client,
151		o.options.languageModelOptions...,
152	), nil
153}
154
155func (o *provider) Name() string {
156	return Name
157}