1package openai
2
import (
	"cmp"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
11
// Name is this provider's identifier. New uses it as the provider name when
// WithName is not supplied (or is given an empty string).
const Name = "openai"

// DefaultURL is the OpenAI API base URL New uses when WithBaseURL is not
// supplied (or is given an empty string).
const DefaultURL = "https://api.openai.com/v1"
16
// provider is the value New returns as an ai.Provider. It holds the resolved
// configuration that LanguageModel uses to build an OpenAI client per call.
type provider struct {
	options options
}
20
// options collects the configuration applied by Option functions in New.
type options struct {
	// baseURL is the API base URL; New replaces an empty value with DefaultURL.
	baseURL string
	// apiKey is passed to the OpenAI client via option.WithAPIKey when non-empty.
	apiKey string
	// organization, when non-empty, is added to headers by New.
	organization string
	// project, when non-empty, is added to headers by New.
	project string
	// name is the provider name forwarded to newLanguageModel; New replaces an
	// empty value with Name.
	name string
	// headers are applied to every client LanguageModel creates, one
	// option.WithHeader per entry.
	headers map[string]string
	// client, when non-nil, is passed to the OpenAI client via
	// option.WithHTTPClient.
	client option.HTTPClient
	// languageModelOptions are forwarded to newLanguageModel for every model.
	languageModelOptions []LanguageModelOption
}
31
// Option configures the provider built by New. It is a type alias for a
// function that mutates the provider's internal options.
type Option = func(*options)
33
34func New(opts ...Option) ai.Provider {
35 providerOptions := options{
36 headers: map[string]string{},
37 languageModelOptions: make([]LanguageModelOption, 0),
38 }
39 for _, o := range opts {
40 o(&providerOptions)
41 }
42
43 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
44 providerOptions.name = cmp.Or(providerOptions.name, Name)
45
46 if providerOptions.organization != "" {
47 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
48 }
49 if providerOptions.project != "" {
50 providerOptions.headers["OpenAi-Project"] = providerOptions.project
51 }
52
53 return &provider{options: providerOptions}
54}
55
56func WithBaseURL(baseURL string) Option {
57 return func(o *options) {
58 o.baseURL = baseURL
59 }
60}
61
62func WithAPIKey(apiKey string) Option {
63 return func(o *options) {
64 o.apiKey = apiKey
65 }
66}
67
68func WithOrganization(organization string) Option {
69 return func(o *options) {
70 o.organization = organization
71 }
72}
73
74func WithProject(project string) Option {
75 return func(o *options) {
76 o.project = project
77 }
78}
79
80func WithName(name string) Option {
81 return func(o *options) {
82 o.name = name
83 }
84}
85
86func WithHeaders(headers map[string]string) Option {
87 return func(o *options) {
88 maps.Copy(o.headers, headers)
89 }
90}
91
92func WithHTTPClient(client option.HTTPClient) Option {
93 return func(o *options) {
94 o.client = client
95 }
96}
97
98func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
99 return func(o *options) {
100 o.languageModelOptions = append(o.languageModelOptions, opts...)
101 }
102}
103
104// LanguageModel implements ai.Provider.
105func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
106 openaiClientOptions := []option.RequestOption{}
107 if o.options.apiKey != "" {
108 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
109 }
110 if o.options.baseURL != "" {
111 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
112 }
113
114 for key, value := range o.options.headers {
115 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
116 }
117
118 if o.options.client != nil {
119 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
120 }
121
122 return newLanguageModel(
123 modelID,
124 o.options.name,
125 openai.NewClient(openaiClientOptions...),
126 o.options.languageModelOptions...,
127 ), nil
128}
129
130func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
131 var options ProviderOptions
132 if err := ai.ParseOptions(data, &options); err != nil {
133 return nil, err
134 }
135 return &options, nil
136}
137
138func (o *provider) Name() string {
139 return Name
140}