bedrock_test.go

 1package providertests
 2
 3import (
 4	"net/http"
 5	"testing"
 6
 7	"charm.land/fantasy"
 8	"charm.land/fantasy/providers/bedrock"
 9	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
10)
11
12func TestBedrockCommon(t *testing.T) {
13	testCommon(t, []builderPair{
14		{"bedrock-anthropic-claude-3-sonnet", builderBedrockClaude3Sonnet, nil, nil},
15		{"bedrock-anthropic-claude-3-haiku", builderBedrockClaude3Haiku, nil, nil},
16	})
17}
18
19func builderBedrockClaude3Sonnet(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
20	t.Setenv("AWS_REGION", "us-east-1")
21	provider, err := bedrock.New(
22		bedrock.WithHTTPClient(&http.Client{Transport: r}),
23		bedrock.WithAPIKey("dummy"),
24	)
25	if err != nil {
26		return nil, err
27	}
28	return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
29}
30
31func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
32	t.Setenv("AWS_REGION", "us-east-1")
33	provider, err := bedrock.New(
34		bedrock.WithHTTPClient(&http.Client{Transport: r}),
35		bedrock.WithAPIKey("dummy"),
36	)
37	if err != nil {
38		return nil, err
39	}
40	return provider.LanguageModel(t.Context(), "anthropic.claude-3-haiku-20240307-v1:0")
41}