diff --git a/src/internal/reflectlite/value.go b/src/internal/reflectlite/value.go index 0236035c50..b8ec96d55f 100644 --- a/src/internal/reflectlite/value.go +++ b/src/internal/reflectlite/value.go @@ -86,6 +86,64 @@ func (v Value) Interface() interface{} { return valueInterfaceUnsafe(v) } +func TypeAssert[T any](v Value) (T, bool) { + if v.typecode == nil { + panic("reflect.TypeAssert: zero Value") + } + if !v.isExported() { + // Do not allow access to unexported values via TypeAssert, + // because they might be pointers that should not be + // writable or methods or function that should not be callable. + panic("reflect.TypeAssert: cannot return value obtained from unexported field or method") + } + + typ := TypeFor[T]() + + // If v is an interface, return the element inside the interface. + // + // T is a concrete type and v is an interface. For example: + // + // var v any = int(1) + // val := ValueOf(&v).Elem() + // TypeAssert[int](val) == val.Interface().(int) + // + // T is a interface and v is a non-nil interface value. For example: + // + // var v any = &someError{} + // val := ValueOf(&v).Elem() + // TypeAssert[error](val) == val.Interface().(error) + // + // T is a interface and v is a nil interface value. For example: + // + // var v error = nil + // val := ValueOf(&v).Elem() + // TypeAssert[error](val) == val.Interface().(error) + if v.Kind() == Interface { + val, ok := valueInterfaceUnsafe(v).(T) + return val, ok + } + + // If T is an interface and v is a concrete type. For example: + // + // TypeAssert[any](ValueOf(1)) == ValueOf(1).Interface().(any) + // TypeAssert[error](ValueOf(&someError{})) == ValueOf(&someError{}).Interface().(error) + if typ.Kind() == Interface { + val, ok := valueInterfaceUnsafe(v).(T) + return val, ok + } + + // Both v and T must be concrete types. + // The only way for an type-assertion to match is if the types are equal. + if typ != v.typecode { + var zero T + return zero, false + } + if !v.isIndirect() { + return *(*T)(unsafe.Pointer(&v.value)), true + } + return *(*T)(v.value), true +} + // valueInterfaceUnsafe is used by the runtime to hash map keys. It should not // be subject to the isExported check. func valueInterfaceUnsafe(v Value) interface{} { diff --git a/src/reflect/value.go b/src/reflect/value.go index a7e4787c67..cf6952a770 100644 --- a/src/reflect/value.go +++ b/src/reflect/value.go @@ -17,6 +17,10 @@ func ValueOf(i interface{}) Value { return Value{reflectlite.ValueOf(i)} } +func TypeAssert[T any](v Value) (T, bool) { + return reflectlite.TypeAssert[T](v.Value) +} + func (v Value) Type() Type { return toType(v.Value.Type()) } diff --git a/src/reflect/value_test.go b/src/reflect/value_test.go index 4df9db4d19..4b818100b5 100644 --- a/src/reflect/value_test.go +++ b/src/reflect/value_test.go @@ -3,7 +3,9 @@ package reflect_test import ( "bytes" "encoding/base64" + "fmt" . "reflect" + "runtime" "slices" "sort" "strings" @@ -869,3 +871,78 @@ func equal[T comparable](a, b []T) bool { } return true } + +func TestTypeAssert(t *testing.T) { + testTypeAssert(t, int(123456789), int(123456789), true) + testTypeAssert(t, int(-123456789), int(-123456789), true) + testTypeAssert(t, int32(123456789), int32(123456789), true) + testTypeAssert(t, int8(-123), int8(-123), true) + testTypeAssert(t, [2]int{1234, -5678}, [2]int{1234, -5678}, true) + testTypeAssert(t, "test value", "test value", true) + testTypeAssert(t, any("test value"), any("test value"), true) + + v := 123456789 + testTypeAssert(t, &v, &v, true) + + testTypeAssert(t, int(123), uint(0), false) + + testTypeAssert[any](t, 1, 1, true) + testTypeAssert[fmt.Stringer](t, 1, nil, false) + + vv := testTypeWithMethod{"test"} + testTypeAssert[any](t, vv, vv, true) + testTypeAssert[any](t, &vv, &vv, true) + testTypeAssert[fmt.Stringer](t, vv, vv, true) + testTypeAssert[fmt.Stringer](t, &vv, &vv, true) + testTypeAssert[interface{ A() }](t, vv, nil, false) + testTypeAssert[interface{ A() }](t, &vv, nil, false) + testTypeAssert(t, any(vv), any(vv), true) + testTypeAssert(t, fmt.Stringer(vv), fmt.Stringer(vv), true) + + testTypeAssert(t, fmt.Stringer(vv), any(vv), true) + testTypeAssert(t, any(vv), fmt.Stringer(vv), true) + testTypeAssert(t, fmt.Stringer(vv), interface{ M() }(vv), true) + testTypeAssert(t, interface{ M() }(vv), fmt.Stringer(vv), true) + + testTypeAssert(t, any(int(1)), int(1), true) + testTypeAssert(t, any(int(1)), byte(0), false) + testTypeAssert(t, fmt.Stringer(vv), vv, true) +} + +func testTypeAssert[T comparable, V any](t *testing.T, val V, wantVal T, wantOk bool) { + t.Helper() + + v, ok := TypeAssert[T](ValueOf(&val).Elem()) + if v != wantVal || ok != wantOk { + t.Errorf("TypeAssert[%v](%#v) = (%#v, %v); want = (%#v, %v)", TypeFor[T](), val, v, ok, wantVal, wantOk) + } + + // Additionally make sure that TypeAssert[T](v) behaves in the same way as v.Interface().(T). + v2, ok2 := ValueOf(&val).Elem().Interface().(T) + if v != v2 || ok != ok2 { + t.Errorf("reflect.ValueOf(%#v).Interface().(%v) = (%#v, %v); want = (%#v, %v)", val, TypeFor[T](), v2, ok2, v, ok) + } +} + +type testTypeWithMethod struct{ val string } + +func (v testTypeWithMethod) String() string { return v.val } +func (v testTypeWithMethod) M() {} + +func TestTypeAssertPanic(t *testing.T) { + if runtime.GOARCH == "wasm" { + t.Log("recover not supported") + return + } + + t.Run("zero val", func(t *testing.T) { + defer func() { recover() }() + TypeAssert[int](Value{}) + t.Fatalf("TypeAssert did not panic") + }) + t.Run("read only", func(t *testing.T) { + defer func() { recover() }() + TypeAssert[int](ValueOf(&testTypeWithMethod{}).FieldByName("val")) + t.Fatalf("TypeAssert did not panic") + }) +}