1package bedrock
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/aws/aws-sdk-go-v2/config"
8 "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
9 "github.com/charmbracelet/fantasy/ai"
10)
11
// Name is the identifier of the Amazon Bedrock provider. It is also the
// default provider name used by New when WithName is not supplied.
//
// No default endpoint URL is defined here; endpoint selection is left to the
// AWS configuration loaded in New.
const Name = "bedrock"
16
17type options struct {
18 name string
19 httpClient bedrockruntime.HTTPClient
20
21 // region string
22 // accessKey string
23 // secretKey string
24 // sessionToken string
25 // endpoint string
26}
27
28type provider struct {
29 options options
30 client *bedrockruntime.Client
31}
32
33type Option = func(*options)
34
// New constructs a Bedrock ai.Provider.
//
// The provider name starts as Name, and then each opt is applied in order.
// AWS settings such as region and credentials come from
// config.LoadDefaultConfig. If loading them fails, the error is returned
// wrapped. If WithHTTPClient supplied a client, it replaces the Bedrock
// runtime client's HTTP client.
func New(ctx context.Context, opts ...Option) (ai.Provider, error) {
	resolved := options{name: Name}
	for _, apply := range opts {
		apply(&resolved)
	}

	// TODO: once region/credential options exist, pass config.WithRegion to
	// LoadDefaultConfig and install static credentials via
	// aws.CredentialsProviderFunc when an access/secret key pair is set.
	awsCfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return nil, fmt.Errorf("fantasy: unable to load default aws config: %w", err)
	}

	overrideHTTPClient := func(clientOpts *bedrockruntime.Options) {
		if resolved.httpClient == nil {
			return // keep the SDK-provided HTTP client
		}
		clientOpts.HTTPClient = resolved.httpClient
	}

	return &provider{
		options: resolved,
		client:  bedrockruntime.NewFromConfig(awsCfg, overrideHTTPClient),
	}, nil
}
74
75func WithName(name string) Option {
76 return func(o *options) {
77 o.name = name
78 }
79}
80
81func WithHTTPClient(httpClient bedrockruntime.HTTPClient) Option {
82 return func(o *options) {
83 o.httpClient = httpClient
84 }
85}
86
87// func WithRegion(region string) Option {
88// return func(o *options) {
89// o.region = region
90// }
91// }
92
93// func WithCredentials(accessKey, secretKey, sessionToken string) Option {
94// return func(o *options) {
95// o.accessKey = accessKey
96// o.secretKey = secretKey
97// o.sessionToken = sessionToken
98// }
99// }
100
101// func WithEndpoint(endpoint string) Option {
102// return func(o *options) {
103// o.endpoint = endpoint
104// }
105// }
106
107func (b *provider) Name() string {
108 return Name
109}
110
111func (b *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
112 return languageModel{
113 modelID: modelID,
114 provider: b.options.name,
115 client: b.client,
116 }, nil
117}
118
119func (b *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
120 var options ProviderOptions
121 if err := ai.ParseOptions(data, &options); err != nil {
122 return nil, err
123 }
124 return &options, nil
125}