// Package openai provides an ai.Provider implementation backed by the OpenAI API.
2package openai
3
import (
	"cmp"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
)
12
const (
	// Name is the default provider name. New uses it when no non-empty
	// name was supplied through WithName.
	Name = "openai"

	// DefaultURL is the base URL New uses when no non-empty URL was
	// supplied through WithBaseURL.
	DefaultURL = "https://api.openai.com/v1"
)
17
18type provider struct {
19 options options
20}
21
22type options struct {
23 baseURL string
24 apiKey string
25 organization string
26 project string
27 name string
28 useResponsesAPI bool
29 headers map[string]string
30 client option.HTTPClient
31 sdkOptions []option.RequestOption
32 languageModelOptions []LanguageModelOption
33}
34
35type Option = func(*options)
36
37func New(opts ...Option) ai.Provider {
38 providerOptions := options{
39 headers: map[string]string{},
40 languageModelOptions: make([]LanguageModelOption, 0),
41 }
42 for _, o := range opts {
43 o(&providerOptions)
44 }
45
46 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
47 providerOptions.name = cmp.Or(providerOptions.name, Name)
48
49 if providerOptions.organization != "" {
50 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
51 }
52 if providerOptions.project != "" {
53 providerOptions.headers["OpenAi-Project"] = providerOptions.project
54 }
55
56 return &provider{options: providerOptions}
57}
58
59func WithBaseURL(baseURL string) Option {
60 return func(o *options) {
61 o.baseURL = baseURL
62 }
63}
64
65func WithAPIKey(apiKey string) Option {
66 return func(o *options) {
67 o.apiKey = apiKey
68 }
69}
70
71func WithOrganization(organization string) Option {
72 return func(o *options) {
73 o.organization = organization
74 }
75}
76
77func WithProject(project string) Option {
78 return func(o *options) {
79 o.project = project
80 }
81}
82
83func WithName(name string) Option {
84 return func(o *options) {
85 o.name = name
86 }
87}
88
89func WithHeaders(headers map[string]string) Option {
90 return func(o *options) {
91 maps.Copy(o.headers, headers)
92 }
93}
94
95func WithHTTPClient(client option.HTTPClient) Option {
96 return func(o *options) {
97 o.client = client
98 }
99}
100
101func WithSDKOptions(opts ...option.RequestOption) Option {
102 return func(o *options) {
103 o.sdkOptions = append(o.sdkOptions, opts...)
104 }
105}
106
107func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
108 return func(o *options) {
109 o.languageModelOptions = append(o.languageModelOptions, opts...)
110 }
111}
112
113// WithUseResponsesAPI makes it so the provider uses responses API for models that support it.
114func WithUseResponsesAPI() Option {
115 return func(o *options) {
116 o.useResponsesAPI = true
117 }
118}
119
120// LanguageModel implements ai.Provider.
121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
122 openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
123
124 if o.options.apiKey != "" {
125 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
126 }
127 if o.options.baseURL != "" {
128 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
129 }
130
131 for key, value := range o.options.headers {
132 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
133 }
134
135 if o.options.client != nil {
136 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
137 }
138
139 openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
140
141 client := openai.NewClient(openaiClientOptions...)
142
143 if o.options.useResponsesAPI && IsResponsesModel(modelID) {
144 return newResponsesLanguageModel(modelID, o.options.name, client), nil
145 }
146
147 return newLanguageModel(
148 modelID,
149 o.options.name,
150 client,
151 o.options.languageModelOptions...,
152 ), nil
153}
154
155func (o *provider) Name() string {
156 return Name
157}