1package openai
  2
import (
	"cmp"
	"maps"
	"slices"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
 11
 12const (
 13	Name       = "openai"
 14	DefaultURL = "https://api.openai.com/v1"
 15)
 16
 17type provider struct {
 18	options options
 19}
 20
 21type options struct {
 22	baseURL              string
 23	apiKey               string
 24	organization         string
 25	project              string
 26	name                 string
 27	headers              map[string]string
 28	client               option.HTTPClient
 29	sdkOptions           []option.RequestOption
 30	languageModelOptions []LanguageModelOption
 31}
 32
 33type Option = func(*options)
 34
 35func New(opts ...Option) ai.Provider {
 36	providerOptions := options{
 37		headers:              map[string]string{},
 38		languageModelOptions: make([]LanguageModelOption, 0),
 39	}
 40	for _, o := range opts {
 41		o(&providerOptions)
 42	}
 43
 44	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 45	providerOptions.name = cmp.Or(providerOptions.name, Name)
 46
 47	if providerOptions.organization != "" {
 48		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 49	}
 50	if providerOptions.project != "" {
 51		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 52	}
 53
 54	return &provider{options: providerOptions}
 55}
 56
 57func WithBaseURL(baseURL string) Option {
 58	return func(o *options) {
 59		o.baseURL = baseURL
 60	}
 61}
 62
 63func WithAPIKey(apiKey string) Option {
 64	return func(o *options) {
 65		o.apiKey = apiKey
 66	}
 67}
 68
 69func WithOrganization(organization string) Option {
 70	return func(o *options) {
 71		o.organization = organization
 72	}
 73}
 74
 75func WithProject(project string) Option {
 76	return func(o *options) {
 77		o.project = project
 78	}
 79}
 80
 81func WithName(name string) Option {
 82	return func(o *options) {
 83		o.name = name
 84	}
 85}
 86
 87func WithHeaders(headers map[string]string) Option {
 88	return func(o *options) {
 89		maps.Copy(o.headers, headers)
 90	}
 91}
 92
 93func WithHTTPClient(client option.HTTPClient) Option {
 94	return func(o *options) {
 95		o.client = client
 96	}
 97}
 98
 99func WithSDKOptions(opts ...option.RequestOption) Option {
100	return func(o *options) {
101		o.sdkOptions = append(o.sdkOptions, opts...)
102	}
103}
104
105func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
106	return func(o *options) {
107		o.languageModelOptions = append(o.languageModelOptions, opts...)
108	}
109}
110
111// LanguageModel implements ai.Provider.
112func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
113	openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
114
115	if o.options.apiKey != "" {
116		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
117	}
118	if o.options.baseURL != "" {
119		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
120	}
121
122	for key, value := range o.options.headers {
123		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
124	}
125
126	if o.options.client != nil {
127		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
128	}
129
130	openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
131
132	return newLanguageModel(
133		modelID,
134		o.options.name,
135		openai.NewClient(openaiClientOptions...),
136		o.options.languageModelOptions...,
137	), nil
138}
139
140func (o *provider) Name() string {
141	return Name
142}