1// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
2package openai
3
4import (
5 "cmp"
6 "maps"
7
8 "charm.land/fantasy"
9 "github.com/openai/openai-go/v2"
10 "github.com/openai/openai-go/v2/option"
11)
12
const (
	// Name is the provider name used by New when no custom name has been
	// configured through WithName.
	Name = "openai"
	// DefaultURL is the OpenAI API base URL used by New when no base URL has
	// been configured through WithBaseURL.
	DefaultURL = "https://api.openai.com/v1"
)
19
20type provider struct {
21 options options
22}
23
24type options struct {
25 baseURL string
26 apiKey string
27 organization string
28 project string
29 name string
30 useResponsesAPI bool
31 headers map[string]string
32 client option.HTTPClient
33 sdkOptions []option.RequestOption
34 languageModelOptions []LanguageModelOption
35}
36
37// Option defines a function that configures OpenAI provider options.
38type Option = func(*options)
39
40// New creates a new OpenAI provider with the given options.
41func New(opts ...Option) (fantasy.Provider, error) {
42 providerOptions := options{
43 headers: map[string]string{},
44 languageModelOptions: make([]LanguageModelOption, 0),
45 }
46 for _, o := range opts {
47 o(&providerOptions)
48 }
49
50 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
51 providerOptions.name = cmp.Or(providerOptions.name, Name)
52
53 if providerOptions.organization != "" {
54 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
55 }
56 if providerOptions.project != "" {
57 providerOptions.headers["OpenAi-Project"] = providerOptions.project
58 }
59
60 return &provider{options: providerOptions}, nil
61}
62
63// WithBaseURL sets the base URL for the OpenAI provider.
64func WithBaseURL(baseURL string) Option {
65 return func(o *options) {
66 o.baseURL = baseURL
67 }
68}
69
70// WithAPIKey sets the API key for the OpenAI provider.
71func WithAPIKey(apiKey string) Option {
72 return func(o *options) {
73 o.apiKey = apiKey
74 }
75}
76
77// WithOrganization sets the organization for the OpenAI provider.
78func WithOrganization(organization string) Option {
79 return func(o *options) {
80 o.organization = organization
81 }
82}
83
84// WithProject sets the project for the OpenAI provider.
85func WithProject(project string) Option {
86 return func(o *options) {
87 o.project = project
88 }
89}
90
91// WithName sets the name for the OpenAI provider.
92func WithName(name string) Option {
93 return func(o *options) {
94 o.name = name
95 }
96}
97
98// WithHeaders sets the headers for the OpenAI provider.
99func WithHeaders(headers map[string]string) Option {
100 return func(o *options) {
101 maps.Copy(o.headers, headers)
102 }
103}
104
105// WithHTTPClient sets the HTTP client for the OpenAI provider.
106func WithHTTPClient(client option.HTTPClient) Option {
107 return func(o *options) {
108 o.client = client
109 }
110}
111
112// WithSDKOptions sets the SDK options for the OpenAI provider.
113func WithSDKOptions(opts ...option.RequestOption) Option {
114 return func(o *options) {
115 o.sdkOptions = append(o.sdkOptions, opts...)
116 }
117}
118
119// WithLanguageModelOptions sets the language model options for the OpenAI provider.
120func WithLanguageModelOptions(opts ...LanguageModelOption) Option {
121 return func(o *options) {
122 o.languageModelOptions = append(o.languageModelOptions, opts...)
123 }
124}
125
126// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
127func WithUseResponsesAPI() Option {
128 return func(o *options) {
129 o.useResponsesAPI = true
130 }
131}
132
133// LanguageModel implements fantasy.Provider.
134func (o *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
135 openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
136
137 if o.options.apiKey != "" {
138 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
139 }
140 if o.options.baseURL != "" {
141 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
142 }
143
144 for key, value := range o.options.headers {
145 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
146 }
147
148 if o.options.client != nil {
149 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
150 }
151
152 openaiClientOptions = append(openaiClientOptions, o.options.sdkOptions...)
153
154 client := openai.NewClient(openaiClientOptions...)
155
156 if o.options.useResponsesAPI && IsResponsesModel(modelID) {
157 return newResponsesLanguageModel(modelID, o.options.name, client), nil
158 }
159
160 return newLanguageModel(
161 modelID,
162 o.options.name,
163 client,
164 o.options.languageModelOptions...,
165 ), nil
166}
167
168func (o *provider) Name() string {
169 return Name
170}