11import torch
22
3+
34def accuracy (logits , labels ):
45 _ , indices = torch .max (logits , dim = 1 )
56 correct = torch .sum (indices == labels )
67 return correct .item () * 1.0 / len (labels )
78
9+
810# GPU | CPU
911def get_default_device ():
10-
1112 if torch .cuda .is_available ():
12- return torch .device (' cuda:0' )
13+ return torch .device (" cuda:0" )
1314 else :
14- return torch .device ('cpu' )
15+ return torch .device ("cpu" )
16+
1517
1618def to_default_device (data ):
17-
18- if isinstance (data ,(list ,tuple )):
19- return [to_default_device (x ,get_default_device ()) for x in data ]
20-
21- return data .to (get_default_device (),non_blocking = True )
19+ if isinstance (data , (list , tuple )):
20+ return [to_default_device (x , get_default_device ()) for x in data ]
21+
22+ return data .to (get_default_device (), non_blocking = True )
23+
24+
25+ def generate_train_mask (size : int , train_test_split : int ) -> list :
26+ cutoff = size * train_test_split
27+ return [1 if i < cutoff else 0 for i in range (size )]
28+
29+
30+ def generate_test_mask (size : int , train_test_split : int ) -> list :
31+ cutoff = size * train_test_split
32+ return [0 if i < cutoff else 1 for i in range (size )]
0 commit comments