@@ -32,7 +32,7 @@ class Matrix:
3232 def zeros (cls , rows : int , cols : int ):
3333 if rows < 1 and cols < 1 :
3434 raise ValueError ("invalid rows or columns provided" )
35- return [[0 ] * rows ] * cols
35+ return Matrix ( [[0 for _ in range ( rows )] for _ in range ( cols )])
3636
3737 def __init__ (self , data : list [list [int | float ]]) -> None :
3838 if data .__len__ ():
@@ -51,25 +51,54 @@ def size(self):
5151 return (self .data [0 ].__len__ (), self .data .__len__ ())
5252
5353 def get_row (self , index : int ):
54- return [ d [ index ] for d in self .data ]
54+ return self .data [ index ]
5555
5656 def get_column (self , index : int ):
57- return self .data [ index ]
57+ return [ d [ index ] for d in self .data ]
5858
5959 def __str__ (self ) -> str :
6060 if self .data .__len__ ():
61- return "| |"
62- return " \n " .join (
63- [ "| " + "" . join ([ f" { i :^5d } " for i in row ]) + " |" for row in self . data ]
64- )
61+ return "\n " . join (
62+ [ "| " + "" .join ([ f" { i :^5d } " for i in row ]) + " |" for row in self . data ]
63+ )
64+ return "| |"
6565
6666 def __mul__ (self , other : "Matrix" ):
6767 """
6868 This method overloads python's default multiplication operation between
6969 two Matrix objects so that we can easily perform `*` operation.
7070 """
71- # TODO
72- pass
71+ (self_rows , self_cols ) = self .size
72+ (other_rows , other_cols ) = other .size
73+ if self_rows != other_cols :
74+ raise ValueError ("Dimensions mismatch" )
75+
76+ # by for loop and matrix.zeros
77+ # this part is easy to understand but is a bit more computationally expensive
78+
79+ # result = Matrix.zeros(other_rows, self_cols)
80+ # for row in range(other_rows):
81+ # for col in range(self_cols):
82+ # result.data[row][col] = sum(
83+ # [a * b for (a, b) in zip(self.get_row(row), other.get_column(col))]
84+ # )
85+ # return result
86+
87+ # by comprehension quicker
88+ return Matrix (
89+ [
90+ [
91+ sum (
92+ [
93+ a * b
94+ for (a , b ) in zip (self .get_row (row ), other .get_column (col ))
95+ ]
96+ )
97+ for col in range (self_cols )
98+ ]
99+ for row in range (other_rows )
100+ ]
101+ )
73102
74103
75104if __name__ == "__main__" :
@@ -83,7 +112,13 @@ def __mul__(self, other: "Matrix"):
83112 # [5, 6],
84113 # ]
85114 # )
86- m1 = Matrix ([[1 , 2 ], [3 , 4 ]])
87- m2 = Matrix ([[2 , 3 ], [4 , 5 ]])
88115
89- print ("result is: \n " , m1 * m2 )
116+ # if we uncomment lines below, we get Dimensions mismatch exception
117+ # m1 = Matrix([[1, 2, 3], [4, 5, 6]])
118+ # m2 = Matrix([[2, 3, 4], [5, 6, 7]])
119+ # print("result is: \n", m1 * m2)
120+
121+ m1 = Matrix ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
122+ m2 = Matrix ([[2 , 3 ], [4 , 5 ], [6 , 7 ]])
123+ print ("m1 X m2 = :" , m1 * m2 , sep = "\n " )
124+ print ("m2 X m1 = :" , m2 * m1 , sep = "\n " )
0 commit comments