1package openai
  2
import (
	"cmp"
	"maps"
	"slices"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
 11
 12const (
 13	Name       = "openai"
 14	DefaultURL = "https://api.openai.com/v1"
 15)
 16
 17type provider struct {
 18	options options
 19}
 20
 21type options struct {
 22	baseURL              string
 23	apiKey               string
 24	organization         string
 25	project              string
 26	name                 string
 27	headers              map[string]string
 28	client               option.HTTPClient
 29	languageModelOptions []LanguageModelOption
 30}
 31
 32type Option = func(*options)
 33
 34func New(opts ...Option) ai.Provider {
 35	providerOptions := options{
 36		headers:              map[string]string{},
 37		languageModelOptions: make([]LanguageModelOption, 0),
 38	}
 39	for _, o := range opts {
 40		o(&providerOptions)
 41	}
 42
 43	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
 44	providerOptions.name = cmp.Or(providerOptions.name, Name)
 45
 46	if providerOptions.organization != "" {
 47		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 48	}
 49	if providerOptions.project != "" {
 50		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 51	}
 52
 53	return &provider{options: providerOptions}
 54}
 55
 56func WithBaseURL(baseURL string) Option {
 57	return func(o *options) {
 58		o.baseURL = baseURL
 59	}
 60}
 61
 62func WithAPIKey(apiKey string) Option {
 63	return func(o *options) {
 64		o.apiKey = apiKey
 65	}
 66}
 67
 68func WithOrganization(organization string) Option {
 69	return func(o *options) {
 70		o.organization = organization
 71	}
 72}
 73
 74func WithProject(project string) Option {
 75	return func(o *options) {
 76		o.project = project
 77	}
 78}
 79
 80func WithName(name string) Option {
 81	return func(o *options) {
 82		o.name = name
 83	}
 84}
 85
 86func WithHeaders(headers map[string]string) Option {
 87	return func(o *options) {
 88		maps.Copy(o.headers, headers)
 89	}
 90}
 91
 92func WithHTTPClient(client option.HTTPClient) Option {
 93	return func(o *options) {
 94		o.client = client
 95	}
 96}
 97
 98func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
 99	return func(o *options) {
100		o.languageModelOptions = append(o.languageModelOptions, opts...)
101	}
102}
103
104// LanguageModel implements ai.Provider.
105func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
106	openaiClientOptions := []option.RequestOption{}
107	if o.options.apiKey != "" {
108		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
109	}
110	if o.options.baseURL != "" {
111		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
112	}
113
114	for key, value := range o.options.headers {
115		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
116	}
117
118	if o.options.client != nil {
119		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
120	}
121
122	return newLanguageModel(
123		modelID,
124		o.options.name,
125		openai.NewClient(openaiClientOptions...),
126		o.options.languageModelOptions...,
127	), nil
128}
129
130func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
131	var options ProviderOptions
132	if err := ai.ParseOptions(data, &options); err != nil {
133		return nil, err
134	}
135	return &options, nil
136}
137
138func (o *provider) Name() string {
139	return Name
140}