Skip to content

Commit fba31e0

Browse files
authored
cross product of two vectors (#687)
* cross product of two vectors * fix test_linalg * fix cmakelist * more explicit test fail message
1 parent f092d06 commit fba31e0

File tree

6 files changed

+246
-2
lines changed

6 files changed

+246
-2
lines changed

doc/specs/stdlib_linalg.md

+30
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,36 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect
160160
{!example/linalg/example_outer_product.f90!}
161161
```
162162

163+
## `cross_product` - Computes the cross product of two vectors
164+
165+
### Status
166+
167+
Experimental
168+
169+
### Description
170+
171+
Computes the cross product of two vectors
172+
173+
### Syntax
174+
175+
`c = [[stdlib_linalg(module):cross_product(interface)]](a, b)`
176+
177+
### Arguments
178+
179+
`a`: Shall be a rank-1 and size-3 array
180+
181+
`b`: Shall be a rank-1 and size-3 array
182+
183+
### Return value
184+
185+
Returns a rank-1 and size-3 array which is perpendicular to both `a` and `b`.
186+
187+
### Example
188+
189+
```fortran
190+
{!example/linalg/example_cross_product.f90!}
191+
```
192+
163193
## `is_square` - Checks if a matrix is square
164194

165195
### Status
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
program demo_cross_product
2+
use stdlib_linalg, only: cross_product
3+
implicit none
4+
real :: a(3), b(3), c(3)
5+
a = [1., 0., 0.]
6+
b = [0., 1., 0.]
7+
c = cross_product(a, b)
8+
!c = [0., 0., 1.]
9+
end program demo_cross_product

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(fppFiles
2222
stdlib_linalg.fypp
2323
stdlib_linalg_diag.fypp
2424
stdlib_linalg_outer_product.fypp
25+
stdlib_linalg_cross_product.fypp
2526
stdlib_optval.fypp
2627
stdlib_selection.fypp
2728
stdlib_sorting.fypp

src/stdlib_linalg.fypp

+16
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module stdlib_linalg
1414
public :: eye
1515
public :: trace
1616
public :: outer_product
17+
public :: cross_product
1718
public :: is_square
1819
public :: is_diagonal
1920
public :: is_symmetric
@@ -93,6 +94,21 @@ module stdlib_linalg
9394
end interface outer_product
9495

9596

97+
! Cross product (of two vectors)
98+
interface cross_product
99+
!! version: experimental
100+
!!
101+
!! Computes the cross product of two vectors, returning a rank-1 and size-3 array
102+
!! ([Specification](../page/specs/stdlib_linalg.html#cross_product-computes-the-cross-product-of-two-3-d-vectors))
103+
#:for k1, t1 in RCI_KINDS_TYPES
104+
pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res)
105+
${t1}$, intent(in) :: a(3), b(3)
106+
${t1}$ :: res(3)
107+
end function cross_product_${t1[0]}$${k1}$
108+
#:endfor
109+
end interface cross_product
110+
111+
96112
! Check for squareness
97113
interface is_square
98114
!! version: experimental

src/stdlib_linalg_cross_product.fypp

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#:include "common.fypp"
2+
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
3+
submodule (stdlib_linalg) stdlib_linalg_cross_product
4+
5+
implicit none
6+
7+
contains
8+
9+
#:for k1, t1 in RCI_KINDS_TYPES
10+
pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res)
11+
${t1}$, intent(in) :: a(3), b(3)
12+
${t1}$ :: res(3)
13+
14+
res(1) = a(2) * b(3) - a(3) * b(2)
15+
res(2) = a(3) * b(1) - a(1) * b(3)
16+
res(3) = a(1) * b(2) - a(2) * b(1)
17+
18+
end function cross_product_${t1[0]}$${k1}$
19+
#:endfor
20+
21+
end submodule

test/linalg/test_linalg.fypp

+169-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module test_linalg
44
use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
55
use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
6-
use stdlib_linalg, only: diag, eye, trace, outer_product
6+
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product
77

88
implicit none
99

@@ -57,7 +57,17 @@ contains
5757
new_unittest("outer_product_int8", test_outer_product_int8), &
5858
new_unittest("outer_product_int16", test_outer_product_int16), &
5959
new_unittest("outer_product_int32", test_outer_product_int32), &
60-
new_unittest("outer_product_int64", test_outer_product_int64) &
60+
new_unittest("outer_product_int64", test_outer_product_int64), &
61+
new_unittest("cross_product_rsp", test_cross_product_rsp), &
62+
new_unittest("cross_product_rdp", test_cross_product_rdp), &
63+
new_unittest("cross_product_rqp", test_cross_product_rqp), &
64+
new_unittest("cross_product_csp", test_cross_product_csp), &
65+
new_unittest("cross_product_cdp", test_cross_product_cdp), &
66+
new_unittest("cross_product_cqp", test_cross_product_cqp), &
67+
new_unittest("cross_product_int8", test_cross_product_int8), &
68+
new_unittest("cross_product_int16", test_cross_product_int16), &
69+
new_unittest("cross_product_int32", test_cross_product_int32), &
70+
new_unittest("cross_product_int64", test_cross_product_int64) &
6171
]
6272

6373
end subroutine collect_linalg
@@ -702,6 +712,163 @@ contains
702712
"all(abs(diff) == 0) failed.")
703713
end subroutine test_outer_product_int64
704714

715+
subroutine test_cross_product_int8(error)
716+
!> Error handling
717+
type(error_type), allocatable, intent(out) :: error
718+
719+
integer, parameter :: n = 3
720+
integer(int8) :: u(n), v(n), expected(n), diff(n)
721+
722+
u = [1,0,0]
723+
v = [0,1,0]
724+
expected = [0,0,1]
725+
diff = expected - cross_product(u,v)
726+
call check(error, all(abs(diff) == 0), &
727+
"cross_product(u,v) == expected failed.")
728+
end subroutine test_cross_product_int8
729+
730+
subroutine test_cross_product_int16(error)
731+
!> Error handling
732+
type(error_type), allocatable, intent(out) :: error
733+
734+
integer, parameter :: n = 3
735+
integer(int16) :: u(n), v(n), expected(n), diff(n)
736+
737+
u = [1,0,0]
738+
v = [0,1,0]
739+
expected = [0,0,1]
740+
diff = expected - cross_product(u,v)
741+
call check(error, all(abs(diff) == 0), &
742+
"cross_product(u,v) == expected failed.")
743+
end subroutine test_cross_product_int16
744+
745+
subroutine test_cross_product_int32(error)
746+
!> Error handling
747+
type(error_type), allocatable, intent(out) :: error
748+
749+
integer, parameter :: n = 3
750+
integer(int32) :: u(n), v(n), expected(n), diff(n)
751+
write(*,*) "test_cross_product_int32"
752+
u = [1,0,0]
753+
v = [0,1,0]
754+
expected = [0,0,1]
755+
diff = expected - cross_product(u,v)
756+
call check(error, all(abs(diff) == 0), &
757+
"cross_product(u,v) == expected failed.")
758+
end subroutine test_cross_product_int32
759+
760+
subroutine test_cross_product_int64(error)
761+
!> Error handling
762+
type(error_type), allocatable, intent(out) :: error
763+
764+
integer, parameter :: n = 3
765+
integer(int64) :: u(n), v(n), expected(n), diff(n)
766+
write(*,*) "test_cross_product_int64"
767+
u = [1,0,0]
768+
v = [0,1,0]
769+
expected = [0,0,1]
770+
diff = expected - cross_product(u,v)
771+
call check(error, all(abs(diff) == 0), &
772+
"cross_product(u,v) == expected failed.")
773+
end subroutine test_cross_product_int64
774+
775+
subroutine test_cross_product_rsp(error)
776+
!> Error handling
777+
type(error_type), allocatable, intent(out) :: error
778+
779+
integer, parameter :: n = 3
780+
real(sp) :: u(n), v(n), expected(n), diff(n)
781+
write(*,*) "test_cross_product_rsp"
782+
u = [1.1_sp,2.5_sp,2.4_sp]
783+
v = [0.5_sp,1.5_sp,2.5_sp]
784+
expected = [2.65_sp,-1.55_sp,0.4_sp]
785+
diff = expected - cross_product(u,v)
786+
call check(error, all(abs(diff) < sptol), &
787+
"all(abs(cross_product(u,v)-expected)) < sptol failed.")
788+
end subroutine test_cross_product_rsp
789+
790+
subroutine test_cross_product_rdp(error)
791+
!> Error handling
792+
type(error_type), allocatable, intent(out) :: error
793+
794+
integer, parameter :: n = 3
795+
real(dp) :: u(n), v(n), expected(n), diff(n)
796+
write(*,*) "test_cross_product_rdp"
797+
u = [1.1_dp,2.5_dp,2.4_dp]
798+
v = [0.5_dp,1.5_dp,2.5_dp]
799+
expected = [2.65_dp,-1.55_dp,0.4_dp]
800+
diff = expected - cross_product(u,v)
801+
call check(error, all(abs(diff) < dptol), &
802+
"all(abs(cross_product(u,v)-expected)) < dptol failed.")
803+
end subroutine test_cross_product_rdp
804+
805+
subroutine test_cross_product_rqp(error)
806+
!> Error handling
807+
type(error_type), allocatable, intent(out) :: error
808+
809+
#:if WITH_QP
810+
integer, parameter :: n = 3
811+
real(qp) :: u(n), v(n), expected(n), diff(n)
812+
write(*,*) "test_cross_product_rqp"
813+
u = [1.1_qp,2.5_qp,2.4_qp]
814+
v = [0.5_qp,1.5_qp,2.5_qp]
815+
expected = [2.65_qp,-1.55_qp,0.4_qp]
816+
diff = expected - cross_product(u,v)
817+
call check(error, all(abs(diff) < qptol), &
818+
"all(abs(cross_product(u,v)-expected)) < qptol failed.")
819+
#:else
820+
call skip_test(error, "Quadruple precision is not enabled")
821+
#:endif
822+
end subroutine test_cross_product_rqp
823+
824+
subroutine test_cross_product_csp(error)
825+
!> Error handling
826+
type(error_type), allocatable, intent(out) :: error
827+
828+
integer, parameter :: n = 3
829+
complex(sp) :: u(n), v(n), expected(n), diff(n)
830+
write(*,*) "test_cross_product_csp"
831+
u = [cmplx(0,1,sp),cmplx(1,0,sp),cmplx(0,0,sp)]
832+
v = [cmplx(1,1,sp),cmplx(0,0,sp),cmplx(1,0,sp)]
833+
expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)]
834+
diff = expected - cross_product(u,v)
835+
call check(error, all(abs(diff) < sptol), &
836+
"all(abs(cross_product(u,v)-expected)) < sptol failed.")
837+
end subroutine test_cross_product_csp
838+
839+
subroutine test_cross_product_cdp(error)
840+
!> Error handling
841+
type(error_type), allocatable, intent(out) :: error
842+
843+
integer, parameter :: n = 3
844+
complex(dp) :: u(n), v(n), expected(n), diff(n)
845+
write(*,*) "test_cross_product_cdp"
846+
u = [cmplx(0,1,dp),cmplx(1,0,dp),cmplx(0,0,dp)]
847+
v = [cmplx(1,1,dp),cmplx(0,0,dp),cmplx(1,0,dp)]
848+
expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)]
849+
diff = expected - cross_product(u,v)
850+
call check(error, all(abs(diff) < dptol), &
851+
"all(abs(cross_product(u,v)-expected)) < dptol failed.")
852+
end subroutine test_cross_product_cdp
853+
854+
subroutine test_cross_product_cqp(error)
855+
!> Error handling
856+
type(error_type), allocatable, intent(out) :: error
857+
858+
#:if WITH_QP
859+
integer, parameter :: n = 3
860+
complex(qp) :: u(n), v(n), expected(n), diff(n)
861+
write(*,*) "test_cross_product_cqp"
862+
u = [cmplx(0,1,qp),cmplx(1,0,qp),cmplx(0,0,qp)]
863+
v = [cmplx(1,1,qp),cmplx(0,0,qp),cmplx(1,0,qp)]
864+
expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)]
865+
diff = expected - cross_product(u,v)
866+
call check(error, all(abs(diff) < qptol), &
867+
"all(abs(cross_product(u,v)-expected)) < qptol failed.")
868+
#:else
869+
call skip_test(error, "Quadruple precision is not enabled")
870+
#:endif
871+
end subroutine test_cross_product_cqp
705872

706873
pure recursive function catalan_number(n) result(value)
707874
integer, intent(in) :: n

0 commit comments

Comments
 (0)