bedrock.go

  1package bedrock
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/aws/aws-sdk-go-v2/config"
  8	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  9	"github.com/charmbracelet/fantasy/ai"
 10)
 11
 12const (
 13	Name = "bedrock"
 14	// DefaultURL = "https://bedrock-runtime.amazonaws.com"
 15)
 16
 17type options struct {
 18	name       string
 19	httpClient bedrockruntime.HTTPClient
 20
 21	// region       string
 22	// accessKey    string
 23	// secretKey    string
 24	// sessionToken string
 25	// endpoint     string
 26}
 27
 28type provider struct {
 29	options options
 30	client  *bedrockruntime.Client
 31}
 32
 33type Option = func(*options)
 34
 35func New(ctx context.Context, opts ...Option) (ai.Provider, error) {
 36	providerOptions := options{
 37		name: Name,
 38	}
 39	for _, o := range opts {
 40		o(&providerOptions)
 41	}
 42
 43	cfg, err := config.LoadDefaultConfig(ctx) //, config.WithRegion(providerOptions.region))
 44	if err != nil {
 45		return nil, fmt.Errorf("fantasy: unable to load default aws config: %w", err)
 46	}
 47
 48	// if providerOptions.accessKey != "" && providerOptions.secretKey != "" {
 49	// 	cfg.Credentials = aws.CredentialsProviderFunc(
 50	// 		func(ctx context.Context) (aws.Credentials, error) {
 51	// 			return aws.Credentials{
 52	// 				AccessKeyID:     providerOptions.accessKey,
 53	// 				SecretAccessKey: providerOptions.secretKey,
 54	// 				SessionToken:    providerOptions.sessionToken,
 55	// 			}, nil
 56	// 		},
 57	// 	)
 58	// }
 59
 60	client := bedrockruntime.NewFromConfig(
 61		cfg,
 62		func(o *bedrockruntime.Options) {
 63			if providerOptions.httpClient != nil {
 64				o.HTTPClient = providerOptions.httpClient
 65			}
 66		},
 67	)
 68
 69	return &provider{
 70		options: providerOptions,
 71		client:  client,
 72	}, nil
 73}
 74
 75func WithName(name string) Option {
 76	return func(o *options) {
 77		o.name = name
 78	}
 79}
 80
 81func WithHTTPClient(httpClient bedrockruntime.HTTPClient) Option {
 82	return func(o *options) {
 83		o.httpClient = httpClient
 84	}
 85}
 86
 87// func WithRegion(region string) Option {
 88// 	return func(o *options) {
 89// 		o.region = region
 90// 	}
 91// }
 92
 93// func WithCredentials(accessKey, secretKey, sessionToken string) Option {
 94// 	return func(o *options) {
 95// 		o.accessKey = accessKey
 96// 		o.secretKey = secretKey
 97// 		o.sessionToken = sessionToken
 98// 	}
 99// }
100
101// func WithEndpoint(endpoint string) Option {
102// 	return func(o *options) {
103// 		o.endpoint = endpoint
104// 	}
105// }
106
107func (b *provider) Name() string {
108	return Name
109}
110
111func (b *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
112	return languageModel{
113		modelID:  modelID,
114		provider: b.options.name,
115		client:   b.client,
116	}, nil
117}
118
119func (b *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
120	var options ProviderOptions
121	if err := ai.ParseOptions(data, &options); err != nil {
122		return nil, err
123	}
124	return &options, nil
125}