From 0e6dc3f253a3cbc6fff87b9f88406889bc1e0475 Mon Sep 17 00:00:00 2001 From: Andrew Stone Date: Mon, 14 Aug 2023 15:29:38 -0700 Subject: [PATCH] check: Print diff for anything that doesn't print nicely For a *int, printing only the addresses when they're different is useless. Now, anything that's non-printable will show a proper diff. --- check/assert_fmt.go | 48 ++++++++++++++++++++++++---------------- check/assert_fmt_test.go | 2 +- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/check/assert_fmt.go b/check/assert_fmt.go index 53001fc..85467fb 100644 --- a/check/assert_fmt.go +++ b/check/assert_fmt.go @@ -21,30 +21,40 @@ func fmtVals(g, e any) (string, string) { return gs, es } -func typeAndKind(v any) (reflect.Type, reflect.Kind) { - t := reflect.TypeOf(v) - k := t.Kind() - if k == reflect.Ptr { - t = t.Elem() - k = t.Kind() +// If fmt.Sprintf("%+v") generally produces a reasonable value +func isFormattable(v any) bool { + rt := reflect.TypeOf(v) + if rt == nil { + return true } - return t, k -} - -func diff(g, e any) string { - if g == nil || e == nil { - return "" + switch rt.Kind() { + case reflect.Bool: + case reflect.Int: + case reflect.Int8: + case reflect.Int16: + case reflect.Int32: + case reflect.Int64: + case reflect.Uint: + case reflect.Uint8: + case reflect.Uint16: + case reflect.Uint32: + case reflect.Uint64: + case reflect.Uintptr: + case reflect.Float32: + case reflect.Float64: + case reflect.Complex64: + case reflect.Complex128: + case reflect.String: + default: + return false } - gt, _ := typeAndKind(g) - et, ek := typeAndKind(e) - - if gt != et { - return "" - } + return true +} - if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { +func diff(g, e any) string { + if isFormattable(g) && isFormattable(e) { return "" } diff --git a/check/assert_fmt_test.go b/check/assert_fmt_test.go index d5fa7e2..53c33fb 100644 --- a/check/assert_fmt_test.go +++ b/check/assert_fmt_test.go @@ -22,7 +22,7 @@ func TestAssertFmtDiffCoverage(t *testing.T) { c.Equal(diff(1, 1), "") v := new(int) - c.Equal(diff(1, v), "") + c.NotEqual(diff(1, v), "") c.NotEqual(diff([]string{"test"}, []string{"nope"}), "") }