From b6ea7dba137a7321124b2b8ca059be00f4e76872 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 9 Jan 2026 10:37:20 -0500 Subject: [PATCH 1/2] go_generics: add Go generics support This is needed to allow templates to produce generic types. --- tools/go_generics/globals/globals_visitor.go | 58 ++++++++++++++++++++ tools/go_generics/tests/typeparams/BUILD | 19 +++++++ tools/go_generics/tests/typeparams/input.go | 46 ++++++++++++++++ tools/go_generics/tests/typeparams/output.go | 30 ++++++++++ 4 files changed, 153 insertions(+) create mode 100644 tools/go_generics/tests/typeparams/BUILD create mode 100644 tools/go_generics/tests/typeparams/input.go create mode 100644 tools/go_generics/tests/typeparams/output.go diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go index 43a733b26c..8cd90c2b27 100644 --- a/tools/go_generics/globals/globals_visitor.go +++ b/tools/go_generics/globals/globals_visitor.go @@ -98,16 +98,30 @@ func (v *globalsVisitor) visitType(ge ast.Expr) { case *ast.ArrayType: v.visitExpr(e.Len) v.visitType(e.Elt) + case *ast.BinaryExpr: + v.visitType(e.X) + v.visitType(e.Y) + case *ast.UnaryExpr: + v.visitType(e.X) case *ast.MapType: v.visitType(e.Key) v.visitType(e.Value) case *ast.StructType: v.visitFields(e.Fields, KindUnknown) case *ast.FuncType: + v.visitFields(e.TypeParams, KindType) v.visitFields(e.Params, KindUnknown) v.visitFields(e.Results, KindUnknown) case *ast.InterfaceType: v.visitFields(e.Methods, KindUnknown) + case *ast.IndexExpr: + v.visitType(e.X) + v.visitType(e.Index) + case *ast.IndexListExpr: + v.visitType(e.X) + for _, index := range e.Indices { + v.visitType(index) + } default: v.unexpected(ge.Pos()) } @@ -151,7 +165,10 @@ func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) { if v.scope.isGlobal() { v.f(s.Name, KindType) } + v.pushScope() + v.visitFields(s.TypeParams, KindType) v.visitType(s.Type) + v.popScope() } case token.CONST, token.VAR: kind := KindConst @@ -193,6 +210,12 @@ func (v *globalsVisitor) isViableType(expr ast.Expr) bool { s := v.scope.deepLookup(e.Name) return s == nil || s.kind == KindType + case *ast.IndexExpr: + return v.isViableType(e.X) + + case *ast.IndexListExpr: + return v.isViableType(e.X) + case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis: // This covers the following cases: // 1. ChanType: @@ -303,6 +326,11 @@ func (v *globalsVisitor) visitExpr(ge ast.Expr) { case *ast.IndexExpr: v.visitExpr(e.X) v.visitExpr(e.Index) + case *ast.IndexListExpr: + v.visitExpr(e.X) + for _, index := range e.Indices { + v.visitExpr(index) + } case *ast.KeyValueExpr: v.visitExpr(e.Value) @@ -498,6 +526,34 @@ func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) { v.popScope() } +func (v *globalsVisitor) addTypeParamsFromRecv(recv *ast.FieldList) { + if recv == nil { + return + } + for _, f := range recv.List { + v.addTypeParamsFromRecvExpr(f.Type) + } +} + +func (v *globalsVisitor) addTypeParamsFromRecvExpr(expr ast.Expr) { + switch e := expr.(type) { + case *ast.ParenExpr: + v.addTypeParamsFromRecvExpr(e.X) + case *ast.StarExpr: + v.addTypeParamsFromRecvExpr(e.X) + case *ast.IndexExpr: + if id := GetIdent(e.Index); id != nil { + v.scope.add(id.Name, KindType, id.Pos()) + } + case *ast.IndexListExpr: + for _, index := range e.Indices { + if id := GetIdent(index); id != nil { + v.scope.add(id.Name, KindType, id.Pos()) + } + } + } +} + // visitFuncDecl is called when a function or method declaration is encountered. // it creates a new scope for the function [optional] receiver, parameters and // results, and visits all children nodes. @@ -508,7 +564,9 @@ func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) { } v.pushScope() + v.addTypeParamsFromRecv(d.Recv) v.visitFields(d.Recv, KindReceiver) + v.visitFields(d.Type.TypeParams, KindType) v.visitFields(d.Type.Params, KindParameter) v.visitFields(d.Type.Results, KindResult) if d.Body != nil { diff --git a/tools/go_generics/tests/typeparams/BUILD b/tools/go_generics/tests/typeparams/BUILD new file mode 100644 index 0000000000..5a2bb028cb --- /dev/null +++ b/tools/go_generics/tests/typeparams/BUILD @@ -0,0 +1,19 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +package(default_applicable_licenses = ["//:license"]) + +go_generics_test( + name = "typeparams", + inputs = ["input.go"], + output = "output.go", + package = "tests", + types = { + "T": "MyT", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/tests/typeparams/input.go b/tools/go_generics/tests/typeparams/input.go new file mode 100644 index 0000000000..0791c37d90 --- /dev/null +++ b/tools/go_generics/tests/typeparams/input.go @@ -0,0 +1,46 @@ +// Copyright 2026 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type Number interface { + ~int | ~int64 +} + +type T struct{} + +type Box[T any] struct { + v T +} + +func (b Box[T]) Get() T { + return b.v +} + +func Use[T Number](v T) Box[T] { + return Box[T]{v: v} +} + +func UseGlobal(x T) T { + return x +} + +type Pair[A, B any] struct { + first A + second B +} + +var _ = Pair[int, string]{} +var _ = Box[int]{} +var _ = Use[int](1) diff --git a/tools/go_generics/tests/typeparams/output.go b/tools/go_generics/tests/typeparams/output.go new file mode 100644 index 0000000000..5609106998 --- /dev/null +++ b/tools/go_generics/tests/typeparams/output.go @@ -0,0 +1,30 @@ +package tests + +type Number interface { + ~int | ~int64 +} + +type Box[T any] struct { + v T +} + +func (b Box[T]) Get() T { + return b.v +} + +func Use[T Number](v T) Box[T] { + return Box[T]{v: v} +} + +func UseGlobal(x MyT) MyT { + return x +} + +type Pair[A, B any] struct { + first A + second B +} + +var _ = Pair[int, string]{} +var _ = Box[int]{} +var _ = Use[int](1) From 9e0fc107b6e8e27bdfaa9208354cdedc4d15d89b Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Thu, 8 Jan 2026 20:55:43 -0500 Subject: [PATCH 2/2] atomicptr: use atomic.Pointer Removes the use of "unsafe". --- pkg/sentry/kernel/BUILD | 14 +++++++------- pkg/sentry/kernel/auth/BUILD | 4 ++-- pkg/sentry/kernel/futex/BUILD | 4 ++-- pkg/sentry/platform/kvm/BUILD | 4 ++-- pkg/sync/atomicptr/BUILD | 6 +++--- ...ic_atomicptr_unsafe.go => generic_atomicptr.go} | 13 ++++--------- 6 files changed, 20 insertions(+), 25 deletions(-) rename pkg/sync/atomicptr/{generic_atomicptr_unsafe.go => generic_atomicptr.go} (73%) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 3f720a7e94..95b376cbec 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -9,7 +9,7 @@ package( go_template_instance( name = "atomicptr_bucket_slice", - out = "atomicptr_bucket_slice_unsafe.go", + out = "atomicptr_bucket_slice.go", package = "kernel", prefix = "descriptorBucketSlice", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -20,7 +20,7 @@ go_template_instance( go_template_instance( name = "atomicptr_bucket", - out = "atomicptr_bucket_unsafe.go", + out = "atomicptr_bucket.go", package = "kernel", prefix = "descriptorBucket", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -31,7 +31,7 @@ go_template_instance( go_template_instance( name = "atomicptr_descriptor", - out = "atomicptr_descriptor_unsafe.go", + out = "atomicptr_descriptor.go", package = "kernel", prefix = "descriptor", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -42,7 +42,7 @@ go_template_instance( go_template_instance( name = "atomicptr_fscontext", - out = "atomicptr_fscontext_unsafe.go", + out = "atomicptr_fscontext.go", package = "kernel", prefix = "fsContext", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -240,9 +240,9 @@ go_library( name = "kernel", srcs = [ "aio.go", - "atomicptr_bucket_slice_unsafe.go", - "atomicptr_bucket_unsafe.go", - "atomicptr_descriptor_unsafe.go", + "atomicptr_bucket.go", + "atomicptr_bucket_slice.go", + "atomicptr_descriptor.go", "cgroup.go", "cgroup_mounts_mutex.go", "cgroup_mutex.go", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 4a6760b651..0b55b2ae65 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -9,7 +9,7 @@ package( go_template_instance( name = "atomicptr_credentials", - out = "atomicptr_credentials_unsafe.go", + out = "atomicptr_credentials.go", package = "auth", suffix = "Credentials", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -71,7 +71,7 @@ declare_mutex( go_library( name = "auth", srcs = [ - "atomicptr_credentials_unsafe.go", + "atomicptr_credentials.go", "auth.go", "capability_set.go", "context.go", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 971dcf111e..5a22f9d851 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -19,7 +19,7 @@ declare_mutex( go_template_instance( name = "atomicptr_bucket", - out = "atomicptr_bucket_unsafe.go", + out = "atomicptr_bucket.go", package = "futex", suffix = "Bucket", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -43,7 +43,7 @@ go_template_instance( go_library( name = "futex", srcs = [ - "atomicptr_bucket_unsafe.go", + "atomicptr_bucket.go", "futex.go", "futex_mutex.go", "waiter_list.go", diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index e38d21efd0..91159071ce 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -8,7 +8,7 @@ package( go_template_instance( name = "atomicptr_machine", - out = "atomicptr_machine_unsafe.go", + out = "atomicptr_machine.go", package = "kvm", prefix = "machine", template = "//pkg/sync/atomicptr:generic_atomicptr", @@ -30,7 +30,7 @@ go_library( "address_space.go", "address_space_amd64.go", "address_space_arm64.go", - "atomicptr_machine_unsafe.go", + "atomicptr_machine.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", diff --git a/pkg/sync/atomicptr/BUILD b/pkg/sync/atomicptr/BUILD index c6e5d7fedb..37268a32bd 100644 --- a/pkg/sync/atomicptr/BUILD +++ b/pkg/sync/atomicptr/BUILD @@ -8,7 +8,7 @@ package( go_template( name = "generic_atomicptr", - srcs = ["generic_atomicptr_unsafe.go"], + srcs = ["generic_atomicptr.go"], types = [ "Value", ], @@ -17,7 +17,7 @@ go_template( go_template_instance( name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", + out = "atomicptr_int.go", package = "atomicptr", suffix = "Int", template = ":generic_atomicptr", @@ -28,7 +28,7 @@ go_template_instance( go_library( name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], + srcs = ["atomicptr_int.go"], ) go_test( diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr.go similarity index 73% rename from pkg/sync/atomicptr/generic_atomicptr_unsafe.go rename to pkg/sync/atomicptr/generic_atomicptr.go index dda08f4864..fee9315643 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr.go @@ -11,7 +11,6 @@ package seqatomic import ( "context" "sync/atomic" - "unsafe" ) // Value is a required type parameter. @@ -20,13 +19,9 @@ type Value struct{} // An AtomicPtr is a pointer to a value of type Value that can be atomically // loaded and stored. The zero value of an AtomicPtr represents nil. // -// Note that copying AtomicPtr by value performs a non-atomic read of the -// stored pointer, which is unsafe if Store() can be called concurrently; in -// this case, do `dst.Store(src.Load())` instead. -// // +stateify savable type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` + ptr atomic.Pointer[Value] `state:".(*Value)"` } func (p *AtomicPtr) savePtr() *Value { @@ -42,15 +37,15 @@ func (p *AtomicPtr) loadPtr(_ context.Context, v *Value) { // //go:nosplit func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) + return p.ptr.Load() } // Store sets the value returned by Load to x. func (p *AtomicPtr) Store(x *Value) { - atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) + p.ptr.Store(x) } // Swap atomically stores `x` into *p and returns the previous *p value. func (p *AtomicPtr) Swap(x *Value) *Value { - return (*Value)(atomic.SwapPointer(&p.ptr, (unsafe.Pointer)(x))) + return p.ptr.Swap(x) }