|
3 | 3 | module test_linalg
|
4 | 4 | use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
|
5 | 5 | use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
|
6 |
| - use stdlib_linalg, only: diag, eye, trace, outer_product |
| 6 | + use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product |
7 | 7 |
|
8 | 8 | implicit none
|
9 | 9 |
|
@@ -57,7 +57,17 @@ contains
|
57 | 57 | new_unittest("outer_product_int8", test_outer_product_int8), &
|
58 | 58 | new_unittest("outer_product_int16", test_outer_product_int16), &
|
59 | 59 | new_unittest("outer_product_int32", test_outer_product_int32), &
|
60 |
| - new_unittest("outer_product_int64", test_outer_product_int64) & |
| 60 | + new_unittest("outer_product_int64", test_outer_product_int64), & |
| 61 | + new_unittest("cross_product_rsp", test_cross_product_rsp), & |
| 62 | + new_unittest("cross_product_rdp", test_cross_product_rdp), & |
| 63 | + new_unittest("cross_product_rqp", test_cross_product_rqp), & |
| 64 | + new_unittest("cross_product_csp", test_cross_product_csp), & |
| 65 | + new_unittest("cross_product_cdp", test_cross_product_cdp), & |
| 66 | + new_unittest("cross_product_cqp", test_cross_product_cqp), & |
| 67 | + new_unittest("cross_product_int8", test_cross_product_int8), & |
| 68 | + new_unittest("cross_product_int16", test_cross_product_int16), & |
| 69 | + new_unittest("cross_product_int32", test_cross_product_int32), & |
| 70 | + new_unittest("cross_product_int64", test_cross_product_int64) & |
61 | 71 | ]
|
62 | 72 |
|
63 | 73 | end subroutine collect_linalg
|
@@ -702,6 +712,163 @@ contains
|
702 | 712 | "all(abs(diff) == 0) failed.")
|
703 | 713 | end subroutine test_outer_product_int64
|
704 | 714 |
|
| 715 | + subroutine test_cross_product_int8(error) |
| 716 | + !> Error handling |
| 717 | + type(error_type), allocatable, intent(out) :: error |
| 718 | + |
| 719 | + integer, parameter :: n = 3 |
| 720 | + integer(int8) :: u(n), v(n), expected(n), diff(n) |
| 721 | + |
| 722 | + u = [1,0,0] |
| 723 | + v = [0,1,0] |
| 724 | + expected = [0,0,1] |
| 725 | + diff = expected - cross_product(u,v) |
| 726 | + call check(error, all(abs(diff) == 0), & |
| 727 | + "cross_product(u,v) == expected failed.") |
| 728 | + end subroutine test_cross_product_int8 |
| 729 | + |
| 730 | + subroutine test_cross_product_int16(error) |
| 731 | + !> Error handling |
| 732 | + type(error_type), allocatable, intent(out) :: error |
| 733 | + |
| 734 | + integer, parameter :: n = 3 |
| 735 | + integer(int16) :: u(n), v(n), expected(n), diff(n) |
| 736 | + |
| 737 | + u = [1,0,0] |
| 738 | + v = [0,1,0] |
| 739 | + expected = [0,0,1] |
| 740 | + diff = expected - cross_product(u,v) |
| 741 | + call check(error, all(abs(diff) == 0), & |
| 742 | + "cross_product(u,v) == expected failed.") |
| 743 | + end subroutine test_cross_product_int16 |
| 744 | + |
| 745 | + subroutine test_cross_product_int32(error) |
| 746 | + !> Error handling |
| 747 | + type(error_type), allocatable, intent(out) :: error |
| 748 | + |
| 749 | + integer, parameter :: n = 3 |
| 750 | + integer(int32) :: u(n), v(n), expected(n), diff(n) |
| 751 | + write(*,*) "test_cross_product_int32" |
| 752 | + u = [1,0,0] |
| 753 | + v = [0,1,0] |
| 754 | + expected = [0,0,1] |
| 755 | + diff = expected - cross_product(u,v) |
| 756 | + call check(error, all(abs(diff) == 0), & |
| 757 | + "cross_product(u,v) == expected failed.") |
| 758 | + end subroutine test_cross_product_int32 |
| 759 | + |
| 760 | + subroutine test_cross_product_int64(error) |
| 761 | + !> Error handling |
| 762 | + type(error_type), allocatable, intent(out) :: error |
| 763 | + |
| 764 | + integer, parameter :: n = 3 |
| 765 | + integer(int64) :: u(n), v(n), expected(n), diff(n) |
| 766 | + write(*,*) "test_cross_product_int64" |
| 767 | + u = [1,0,0] |
| 768 | + v = [0,1,0] |
| 769 | + expected = [0,0,1] |
| 770 | + diff = expected - cross_product(u,v) |
| 771 | + call check(error, all(abs(diff) == 0), & |
| 772 | + "cross_product(u,v) == expected failed.") |
| 773 | + end subroutine test_cross_product_int64 |
| 774 | + |
| 775 | + subroutine test_cross_product_rsp(error) |
| 776 | + !> Error handling |
| 777 | + type(error_type), allocatable, intent(out) :: error |
| 778 | + |
| 779 | + integer, parameter :: n = 3 |
| 780 | + real(sp) :: u(n), v(n), expected(n), diff(n) |
| 781 | + write(*,*) "test_cross_product_rsp" |
| 782 | + u = [1.1_sp,2.5_sp,2.4_sp] |
| 783 | + v = [0.5_sp,1.5_sp,2.5_sp] |
| 784 | + expected = [2.65_sp,-1.55_sp,0.4_sp] |
| 785 | + diff = expected - cross_product(u,v) |
| 786 | + call check(error, all(abs(diff) < sptol), & |
| 787 | + "all(abs(cross_product(u,v)-expected)) < sptol failed.") |
| 788 | + end subroutine test_cross_product_rsp |
| 789 | + |
| 790 | + subroutine test_cross_product_rdp(error) |
| 791 | + !> Error handling |
| 792 | + type(error_type), allocatable, intent(out) :: error |
| 793 | + |
| 794 | + integer, parameter :: n = 3 |
| 795 | + real(dp) :: u(n), v(n), expected(n), diff(n) |
| 796 | + write(*,*) "test_cross_product_rdp" |
| 797 | + u = [1.1_dp,2.5_dp,2.4_dp] |
| 798 | + v = [0.5_dp,1.5_dp,2.5_dp] |
| 799 | + expected = [2.65_dp,-1.55_dp,0.4_dp] |
| 800 | + diff = expected - cross_product(u,v) |
| 801 | + call check(error, all(abs(diff) < dptol), & |
| 802 | + "all(abs(cross_product(u,v)-expected)) < dptol failed.") |
| 803 | + end subroutine test_cross_product_rdp |
| 804 | + |
| 805 | + subroutine test_cross_product_rqp(error) |
| 806 | + !> Error handling |
| 807 | + type(error_type), allocatable, intent(out) :: error |
| 808 | + |
| 809 | +#:if WITH_QP |
| 810 | + integer, parameter :: n = 3 |
| 811 | + real(qp) :: u(n), v(n), expected(n), diff(n) |
| 812 | + write(*,*) "test_cross_product_rqp" |
| 813 | + u = [1.1_qp,2.5_qp,2.4_qp] |
| 814 | + v = [0.5_qp,1.5_qp,2.5_qp] |
| 815 | + expected = [2.65_qp,-1.55_qp,0.4_qp] |
| 816 | + diff = expected - cross_product(u,v) |
| 817 | + call check(error, all(abs(diff) < qptol), & |
| 818 | + "all(abs(cross_product(u,v)-expected)) < qptol failed.") |
| 819 | +#:else |
| 820 | + call skip_test(error, "Quadruple precision is not enabled") |
| 821 | +#:endif |
| 822 | + end subroutine test_cross_product_rqp |
| 823 | + |
| 824 | + subroutine test_cross_product_csp(error) |
| 825 | + !> Error handling |
| 826 | + type(error_type), allocatable, intent(out) :: error |
| 827 | + |
| 828 | + integer, parameter :: n = 3 |
| 829 | + complex(sp) :: u(n), v(n), expected(n), diff(n) |
| 830 | + write(*,*) "test_cross_product_csp" |
| 831 | + u = [cmplx(0,1,sp),cmplx(1,0,sp),cmplx(0,0,sp)] |
| 832 | + v = [cmplx(1,1,sp),cmplx(0,0,sp),cmplx(1,0,sp)] |
| 833 | + expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)] |
| 834 | + diff = expected - cross_product(u,v) |
| 835 | + call check(error, all(abs(diff) < sptol), & |
| 836 | + "all(abs(cross_product(u,v)-expected)) < sptol failed.") |
| 837 | + end subroutine test_cross_product_csp |
| 838 | + |
| 839 | + subroutine test_cross_product_cdp(error) |
| 840 | + !> Error handling |
| 841 | + type(error_type), allocatable, intent(out) :: error |
| 842 | + |
| 843 | + integer, parameter :: n = 3 |
| 844 | + complex(dp) :: u(n), v(n), expected(n), diff(n) |
| 845 | + write(*,*) "test_cross_product_cdp" |
| 846 | + u = [cmplx(0,1,dp),cmplx(1,0,dp),cmplx(0,0,dp)] |
| 847 | + v = [cmplx(1,1,dp),cmplx(0,0,dp),cmplx(1,0,dp)] |
| 848 | + expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)] |
| 849 | + diff = expected - cross_product(u,v) |
| 850 | + call check(error, all(abs(diff) < dptol), & |
| 851 | + "all(abs(cross_product(u,v)-expected)) < dptol failed.") |
| 852 | + end subroutine test_cross_product_cdp |
| 853 | + |
| 854 | + subroutine test_cross_product_cqp(error) |
| 855 | + !> Error handling |
| 856 | + type(error_type), allocatable, intent(out) :: error |
| 857 | + |
| 858 | +#:if WITH_QP |
| 859 | + integer, parameter :: n = 3 |
| 860 | + complex(qp) :: u(n), v(n), expected(n), diff(n) |
| 861 | + write(*,*) "test_cross_product_cqp" |
| 862 | + u = [cmplx(0,1,qp),cmplx(1,0,qp),cmplx(0,0,qp)] |
| 863 | + v = [cmplx(1,1,qp),cmplx(0,0,qp),cmplx(1,0,qp)] |
| 864 | + expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)] |
| 865 | + diff = expected - cross_product(u,v) |
| 866 | + call check(error, all(abs(diff) < qptol), & |
| 867 | + "all(abs(cross_product(u,v)-expected)) < qptol failed.") |
| 868 | +#:else |
| 869 | + call skip_test(error, "Quadruple precision is not enabled") |
| 870 | +#:endif |
| 871 | + end subroutine test_cross_product_cqp |
705 | 872 |
|
706 | 873 | pure recursive function catalan_number(n) result(value)
|
707 | 874 | integer, intent(in) :: n
|
|
0 commit comments