1package openaicompat
 2
 3import (
 4	"charm.land/fantasy/ai"
 5	"charm.land/fantasy/openai"
 6	"github.com/openai/openai-go/v2/option"
 7)
 8
// options accumulates the configuration produced by Option values. New
// turns it into the argument list for openai.New.
type options struct {
	// openaiOptions are passed straight through to openai.New.
	openaiOptions        []openai.Option
	// languageModelOptions are bundled by New into a single
	// openai.WithLanguageModelOptions option.
	languageModelOptions []openai.LanguageModelOption
	// sdkOptions are bundled by New into a single openai.WithSDKOptions
	// option.
	sdkOptions           []option.RequestOption
}
14
// Name is the default provider name that New registers with the underlying
// openai provider via openai.WithName.
const Name = "openai-compat"
18
// Option is a functional option for New. It is a type alias (not a distinct
// type) for a function that mutates the package's unexported options struct.
// New applies options in the order they are passed.
type Option = func(*options)
20
21func New(opts ...Option) ai.Provider {
22	providerOptions := options{
23		openaiOptions: []openai.Option{
24			openai.WithName(Name),
25		},
26		languageModelOptions: []openai.LanguageModelOption{
27			openai.WithLanguageModelPrepareCallFunc(PrepareCallFunc),
28			openai.WithLanguageModelStreamExtraFunc(StreamExtraFunc),
29			openai.WithLanguageModelExtraContentFunc(ExtraContentFunc),
30		},
31	}
32	for _, o := range opts {
33		o(&providerOptions)
34	}
35
36	providerOptions.openaiOptions = append(
37		providerOptions.openaiOptions,
38		openai.WithSDKOptions(providerOptions.sdkOptions...),
39		openai.WithLanguageModelOptions(providerOptions.languageModelOptions...),
40	)
41	return openai.New(providerOptions.openaiOptions...)
42}
43
44func WithBaseURL(url string) Option {
45	return func(o *options) {
46		o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(url))
47	}
48}
49
50func WithAPIKey(apiKey string) Option {
51	return func(o *options) {
52		o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
53	}
54}
55
56func WithName(name string) Option {
57	return func(o *options) {
58		o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
59	}
60}
61
62func WithHeaders(headers map[string]string) Option {
63	return func(o *options) {
64		o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
65	}
66}
67
68func WithHTTPClient(client option.HTTPClient) Option {
69	return func(o *options) {
70		o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
71	}
72}
73
74func WithSDKOptions(opts ...option.RequestOption) Option {
75	return func(o *options) {
76		o.sdkOptions = append(o.sdkOptions, opts...)
77	}
78}