1package openai
2
import (
	"cmp"
	"maps"
	"slices"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
11
// Provider identification and endpoint defaults.
const (
	// DefaultURL is the OpenAI API base URL that New uses when no base URL
	// (or an empty one) is supplied via WithBaseURL.
	DefaultURL = "https://api.openai.com/v1"
	// Name is the provider identifier returned by the provider's Name method
	// and the fallback provider name when WithName is not used.
	Name = "openai"
)
16
// provider is the ai.Provider implementation returned by New. It keeps the
// resolved configuration and builds a new OpenAI SDK client on every
// LanguageModel call.
type provider struct {
	// options is the configuration after New applied all Options and defaults.
	options options
}
20
// options holds the provider configuration assembled by New from Option values.
type options struct {
	// baseURL is the API base URL; New falls back to DefaultURL when empty.
	baseURL string
	// apiKey is passed to the SDK via option.WithAPIKey when non-empty.
	apiKey string
	// organization, when non-empty, is added by New as the
	// "OpenAi-Organization" header.
	organization string
	// project, when non-empty, is added by New as the "OpenAi-Project" header.
	project string
	// name is the provider name handed to language models; New falls back to
	// Name when empty.
	name string
	// headers are extra HTTP headers added to each SDK client via
	// option.WithHeader.
	headers map[string]string
	// client, when non-nil, is passed to the SDK via option.WithHTTPClient.
	client option.HTTPClient
	// sdkOptions are raw SDK request options appended after all options
	// derived from the fields above.
	sdkOptions []option.RequestOption
	// languageModelOptions are forwarded to every language model created by
	// the provider.
	languageModelOptions []LanguageModelOption
}
32
// Option configures the provider built by New. It is an alias for a function
// that mutates the package's unexported options struct; Options are applied
// in the order they are passed to New.
type Option = func(*options)
34
35func New(opts ...Option) ai.Provider {
36 providerOptions := options{
37 headers: map[string]string{},
38 languageModelOptions: make([]LanguageModelOption, 0),
39 }
40 for _, o := range opts {
41 o(&providerOptions)
42 }
43
44 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
45 providerOptions.name = cmp.Or(providerOptions.name, Name)
46
47 if providerOptions.organization != "" {
48 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
49 }
50 if providerOptions.project != "" {
51 providerOptions.headers["OpenAi-Project"] = providerOptions.project
52 }
53
54 return &provider{options: providerOptions}
55}
56
57func WithBaseURL(baseURL string) Option {
58 return func(o *options) {
59 o.baseURL = baseURL
60 }
61}
62
63func WithAPIKey(apiKey string) Option {
64 return func(o *options) {
65 o.apiKey = apiKey
66 }
67}
68
69func WithOrganization(organization string) Option {
70 return func(o *options) {
71 o.organization = organization
72 }
73}
74
75func WithProject(project string) Option {
76 return func(o *options) {
77 o.project = project
78 }
79}
80
81func WithName(name string) Option {
82 return func(o *options) {
83 o.name = name
84 }
85}
86
87func WithHeaders(headers map[string]string) Option {
88 return func(o *options) {
89 maps.Copy(o.headers, headers)
90 }
91}
92
93func WithHTTPClient(client option.HTTPClient) Option {
94 return func(o *options) {
95 o.client = client
96 }
97}
98
99func WithSDKOptions(opts ...option.RequestOption) Option {
100 return func(o *options) {
101 o.sdkOptions = append(o.sdkOptions, opts...)
102 }
103}
104
105func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
106 return func(o *options) {
107 o.languageModelOptions = append(o.languageModelOptions, opts...)
108 }
109}
110
111// LanguageModel implements ai.Provider.
112func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
113 openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
114
115 if o.options.apiKey != "" {
116 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
117 }
118 if o.options.baseURL != "" {
119 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
120 }
121
122 for key, value := range o.options.headers {
123 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
124 }
125
126 if o.options.client != nil {
127 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
128 }
129
130 openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
131
132 return newLanguageModel(
133 modelID,
134 o.options.name,
135 openai.NewClient(openaiClientOptions...),
136 o.options.languageModelOptions...,
137 ), nil
138}
139
140func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
141 var options ProviderOptions
142 if err := ai.ParseOptions(data, &options); err != nil {
143 return nil, err
144 }
145 return &options, nil
146}
147
148func (o *provider) Name() string {
149 return Name
150}