1package providertests
2
3import (
4 "net/http"
5 "testing"
6
7 "charm.land/fantasy"
8 "charm.land/fantasy/providers/bedrock"
9 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
10)
11
12func TestBedrockCommon(t *testing.T) {
13 testCommon(t, []builderPair{
14 {"bedrock-anthropic-claude-3-sonnet", builderBedrockClaude3Sonnet, nil, nil},
15 {"bedrock-anthropic-claude-3-haiku", builderBedrockClaude3Haiku, nil, nil},
16 })
17}
18
19func builderBedrockClaude3Sonnet(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
20 t.Setenv("AWS_REGION", "us-east-1")
21 provider, err := bedrock.New(
22 bedrock.WithHTTPClient(&http.Client{Transport: r}),
23 bedrock.WithAPIKey("dummy"),
24 )
25 if err != nil {
26 return nil, err
27 }
28 return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
29}
30
31func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
32 t.Setenv("AWS_REGION", "us-east-1")
33 provider, err := bedrock.New(
34 bedrock.WithHTTPClient(&http.Client{Transport: r}),
35 bedrock.WithAPIKey("dummy"),
36 )
37 if err != nil {
38 return nil, err
39 }
40 return provider.LanguageModel(t.Context(), "anthropic.claude-3-haiku-20240307-v1:0")
41}