diff --git a/compiler/internal/typeparams/collect.go b/compiler/internal/typeparams/collect.go index 6f0da9973..d8f42c147 100644 --- a/compiler/internal/typeparams/collect.go +++ b/compiler/internal/typeparams/collect.go @@ -99,7 +99,7 @@ func ToSlice(tpl *types.TypeParamList) []*types.TypeParam { // function, Resolver must be provided mapping the type parameters into concrete // types. type visitor struct { - instances *InstanceSet + instances *PackageInstanceSets resolver *Resolver info *types.Info } @@ -216,11 +216,11 @@ func (c *seedVisitor) Visit(n ast.Node) ast.Visitor { type Collector struct { TContext *types.Context Info *types.Info - Instances *InstanceSet + Instances *PackageInstanceSets } // Scan package files for generic instances. -func (c *Collector) Scan(files ...*ast.File) { +func (c *Collector) Scan(pkg *types.Package, files ...*ast.File) { if c.Info.Instances == nil || c.Info.Defs == nil { panic(fmt.Errorf("types.Info must have Instances and Defs populated")) } @@ -240,8 +240,8 @@ func (c *Collector) Scan(files ...*ast.File) { ast.Walk(&sc, file) } - for !c.Instances.exhausted() { - inst, _ := c.Instances.next() + for iset := c.Instances.Pkg(pkg); !iset.exhausted(); { + inst, _ := iset.next() switch typ := inst.Object.Type().(type) { case *types.Signature: v := visitor{ diff --git a/compiler/internal/typeparams/collect_test.go b/compiler/internal/typeparams/collect_test.go index 99b288552..9bd5faee4 100644 --- a/compiler/internal/typeparams/collect_test.go +++ b/compiler/internal/typeparams/collect_test.go @@ -252,12 +252,12 @@ func TestVisitor(t *testing.T) { for _, test := range tests { t.Run(test.descr, func(t *testing.T) { v := visitor{ - instances: &InstanceSet{}, + instances: &PackageInstanceSets{}, resolver: test.resolver, info: info, } ast.Walk(&v, test.node) - got := v.instances.Values() + got := v.instances.Pkg(pkg).Values() if diff := cmp.Diff(test.want, got, instanceOpts()); diff != "" { t.Errorf("Discovered instance diff (-want,+got):\n%s", diff) } @@ -285,7 +285,7 @@ func TestSeedVisitor(t *testing.T) { sv := seedVisitor{ visitor: visitor{ - instances: &InstanceSet{}, + instances: &PackageInstanceSets{}, resolver: nil, info: info, }, @@ -317,7 +317,7 @@ func TestSeedVisitor(t *testing.T) { tInst(types.Typ[types.Int64]), mInst(types.Typ[types.Int64]), } - got := sv.instances.Values() + got := sv.instances.Pkg(pkg).Values() if diff := cmp.Diff(want, got, instanceOpts()); diff != "" { t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff) } @@ -349,9 +349,9 @@ func TestCollector(t *testing.T) { c := Collector{ TContext: types.NewContext(), Info: info, - Instances: &InstanceSet{}, + Instances: &PackageInstanceSets{}, } - c.Scan(file) + c.Scan(pkg, file) inst := func(name string, tArg types.Type) Instance { return Instance{ @@ -371,12 +371,76 @@ func TestCollector(t *testing.T) { inst("fun", types.Typ[types.Int64]), inst("fun.nested", types.Typ[types.Int64]), } - got := c.Instances.Values() + got := c.Instances.Pkg(pkg).Values() if diff := cmp.Diff(want, got, instanceOpts()); diff != "" { t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff) } } +func TestCollector_CrossPackage(t *testing.T) { + f := srctesting.New(t) + const src = `package foo + type X[T any] struct {Value T} + + func F[G any](g G) { + x := X[G]{} + println(x) + } + + func DoFoo() { + F(int8(8)) + } + ` + fooFile := f.Parse("foo.go", src) + _, fooPkg := f.Check("pkg/foo", fooFile) + + const src2 = `package bar + import "pkg/foo" + func FProxy[T any](t T) { + foo.F[T](t) + } + func DoBar() { + FProxy(int16(16)) + } + ` + barFile := f.Parse("bar.go", src2) + _, barPkg := f.Check("pkg/bar", barFile) + + c := Collector{ + TContext: types.NewContext(), + Info: f.Info, + Instances: &PackageInstanceSets{}, + } + c.Scan(barPkg, barFile) + c.Scan(fooPkg, fooFile) + + inst := func(pkg *types.Package, name string, tArg types.BasicKind) Instance { + return Instance{ + Object: srctesting.LookupObj(pkg, name), + TArgs: []types.Type{types.Typ[tArg]}, + } + } + + wantFooInstances := []Instance{ + inst(fooPkg, "F", types.Int16), // Found in "pkg/foo". + inst(fooPkg, "F", types.Int8), + inst(fooPkg, "X", types.Int16), // Found due to F[int16] found in "pkg/foo". + inst(fooPkg, "X", types.Int8), + } + gotFooInstances := c.Instances.Pkg(fooPkg).Values() + if diff := cmp.Diff(wantFooInstances, gotFooInstances, instanceOpts()); diff != "" { + t.Errorf("Instances from pkg/foo contain diff (-want,+got):\n%s", diff) + } + + wantBarInstances := []Instance{ + inst(barPkg, "FProxy", types.Int16), + } + gotBarInstances := c.Instances.Pkg(barPkg).Values() + if diff := cmp.Diff(wantBarInstances, gotBarInstances, instanceOpts()); diff != "" { + t.Errorf("Instances from pkg/foo contain diff (-want,+got):\n%s", diff) + } +} + func TestResolver_SubstituteSelection(t *testing.T) { tests := []struct { descr string diff --git a/compiler/internal/typeparams/instance.go b/compiler/internal/typeparams/instance.go index a64e3be8a..87240c077 100644 --- a/compiler/internal/typeparams/instance.go +++ b/compiler/internal/typeparams/instance.go @@ -137,3 +137,33 @@ func (iset *InstanceSet) exhausted() bool { return len(iset.values) <= iset.unpr func (iset *InstanceSet) Values() []Instance { return iset.values } + +// PackageInstanceSets stores an InstanceSet for each package in a program, keyed +// by import path. +type PackageInstanceSets map[string]*InstanceSet + +// Pkg returns InstanceSet for objects defined in the given package. +func (i PackageInstanceSets) Pkg(pkg *types.Package) *InstanceSet { + path := pkg.Path() + iset, ok := i[path] + if !ok { + iset = &InstanceSet{} + i[path] = iset + } + return iset +} + +// Add instances to the appropriate package's set. Automatically initialized +// new per-package sets upon a first encounter. +func (i PackageInstanceSets) Add(instances ...Instance) { + for _, inst := range instances { + i.Pkg(inst.Object.Pkg()).Add(inst) + } +} + +// ID returns a unique numeric identifier assigned to an instance in the set. +// +// See: InstanceSet.ID(). +func (i PackageInstanceSets) ID(inst Instance) int { + return i.Pkg(inst.Object.Pkg()).ID(inst) +} diff --git a/compiler/internal/typeparams/instance_test.go b/compiler/internal/typeparams/instance_test.go index a5273f883..154e95b82 100644 --- a/compiler/internal/typeparams/instance_test.go +++ b/compiler/internal/typeparams/instance_test.go @@ -204,3 +204,63 @@ func TestInstanceQueue(t *testing.T) { t.Errorf("set.Values() returned diff (-want,+got):\n%s", diff) } } + +func TestInstancesByPackage(t *testing.T) { + f := srctesting.New(t) + + const src1 = `package foo + type Typ[T any, V any] []T + ` + _, foo := f.Check("pkg/foo", f.Parse("foo.go", src1)) + + const src2 = `package bar + func Fun[U any, W any](x, y U) {} + ` + _, bar := f.Check("pkg/bar", f.Parse("bar.go", src2)) + + i1 := Instance{ + Object: foo.Scope().Lookup("Typ"), + TArgs: []types.Type{types.Typ[types.String], types.Typ[types.String]}, + } + i2 := Instance{ + Object: foo.Scope().Lookup("Typ"), + TArgs: []types.Type{types.Typ[types.Int], types.Typ[types.Int]}, + } + i3 := Instance{ + Object: bar.Scope().Lookup("Fun"), + TArgs: []types.Type{types.Typ[types.String], types.Typ[types.String]}, + } + + t.Run("Add", func(t *testing.T) { + instByPkg := PackageInstanceSets{} + instByPkg.Add(i1, i2, i3) + + gotFooInstances := instByPkg.Pkg(foo).Values() + wantFooInstances := []Instance{i1, i2} + if diff := cmp.Diff(wantFooInstances, gotFooInstances, instanceOpts()); diff != "" { + t.Errorf("instByPkg.Pkg(foo).Values() returned diff (-want,+got):\n%s", diff) + } + + gotValues := instByPkg.Pkg(bar).Values() + wantValues := []Instance{i3} + if diff := cmp.Diff(wantValues, gotValues, instanceOpts()); diff != "" { + t.Errorf("instByPkg.Pkg(bar).Values() returned diff (-want,+got):\n%s", diff) + } + }) + + t.Run("ID", func(t *testing.T) { + instByPkg := PackageInstanceSets{} + instByPkg.Add(i1, i2, i3) + + got := []int{ + instByPkg.ID(i1), + instByPkg.ID(i2), + instByPkg.ID(i3), + } + want := []int{0, 1, 0} + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("unexpected instance IDs assigned (-want,+got):\n%s", diff) + } + }) +} diff --git a/compiler/package.go b/compiler/package.go index 112a16af7..1abadf275 100644 --- a/compiler/package.go +++ b/compiler/package.go @@ -39,7 +39,7 @@ type pkgContext struct { minify bool fileSet *token.FileSet errList ErrorList - instanceSet *typeparams.InstanceSet + instanceSet *typeparams.PackageInstanceSets } // funcContext maintains compiler context for a specific function (lexical scope?). @@ -215,11 +215,11 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor tc := typeparams.Collector{ TContext: config.Context, Info: typesInfo, - Instances: &typeparams.InstanceSet{}, + Instances: &typeparams.PackageInstanceSets{}, } - tc.Scan(simplifiedFiles...) + tc.Scan(typesPkg, simplifiedFiles...) instancesByObj := map[types.Object][]typeparams.Instance{} - for _, inst := range tc.Instances.Values() { + for _, inst := range tc.Instances.Pkg(typesPkg).Values() { instancesByObj[inst.Object] = append(instancesByObj[inst.Object], inst) }