From 654e26ffd024accf6d43e84d8b9ba8f6fafeefc8 Mon Sep 17 00:00:00 2001
From: meetdoshi90
Date: Tue, 31 Aug 2021 16:10:21 +0530
Subject: [PATCH] Meet

---
 tf/edgeml_tf/tflite/bonsaiLayer.py    |  59 ++++++
 tf/edgeml_tf/trainer/bonsaiTrainer.py | 278 +++++++++++++++++++++++++-
 tf/edgeml_tf/utils.py                 |   5 +
 3 files changed, 337 insertions(+), 5 deletions(-)

diff --git a/tf/edgeml_tf/tflite/bonsaiLayer.py b/tf/edgeml_tf/tflite/bonsaiLayer.py
index f8df1c373..addb8e511 100644
--- a/tf/edgeml_tf/tflite/bonsaiLayer.py
+++ b/tf/edgeml_tf/tflite/bonsaiLayer.py
@@ -72,6 +72,65 @@ def call(self, X):
         self.prediction = self.score

         return self.prediction
+
+    def weighted_call(self, X):
+        '''
+        The original Bonsai learns a single, shallow, sparse tree whose prediction
+        for a point x is
+            y(x) = sum_k I_k(x) (W_k^T Zx) o tanh(sigma V_k^T Zx)
+
+        This proposed variant adds path smoothing, which improves accuracy for deeper
+        trees. The score is a weighted average over every node on the path from the
+        root to the leaf, so each node contributes to the prediction.
+
+        The contribution grows with depth, and because the Bonsai tree is balanced,
+        the weight of each path depends only on the number of classes and the depth.
+        Even if a leaf sees few observations, the non-leaf nodes correct the target
+        distribution by smoothing it out; think of it as a hierarchical prior.
+        This lets us grow deeper trees while mitigating overfitting.
+            y(x) = (sum_k I_k(x) hierarchical_prior(k) (W_k^T Zx) o tanh(V_k^T Zx)) / k
+        The path (which nodes are visited) is still governed by the indicator I_k(x);
+        each visited node is weighted by
+            hierarchical_prior(k) = (n_classes * (1 + sigma)**depth(k)) / n_classes
+        where sigma is a small positive number (values in [0.01, 0.1] work well).
+        Future work: add tree pruning to reduce the number of nodes, and use CUDA or
+        Dask for distributed training.
+        '''
+        sigmaI = self.sigma
+        errmsg = "Dimension Mismatch, X is [_, self.dataDimension]"
+        assert (len(X.shape) == 2 and int(X.shape[1]) == self.dataDimension), errmsg
+        X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True), self.projectionDimension)
+        W_ = self.W[0:(self.numClasses)]
+        V_ = self.V[0:(self.numClasses)]
+        __nodeProb = []
+        __nodeProb.append(1)
+        # Node count starts from 1 to avoid division by zero in the prior.
+        hierarchical_prior = (self.numClasses * (1 + sigmaI)) / self.numClasses
+        score_ = __nodeProb[0] * hierarchical_prior * tf.multiply(tf.matmul(W_, X_), tf.tanh(tf.matmul(V_, X_)))
+        for i in range(1, self.totalNodes):
+            W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
+            V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]
+            T_ = tf.reshape(self.T[int(np.ceil(i / 2.0) - 1.0)], [-1, self.projectionDimension])
+            # Indicator function: probability of reaching node i from its parent.
+            prob = (1 + ((-1) ** (i + 1)) * tf.tanh(tf.multiply(sigmaI, tf.matmul(T_, X_))))
+            prob = tf.divide(prob, 2.0)
+            prob = __nodeProb[int(np.ceil(i / 2.0) - 1.0)] * prob
+            __nodeProb.append(prob)
+            # Weighted prior: the depth of node i in a balanced tree is log2(i + 1).
+            hierarchical_prior = (self.numClasses * (1 + sigmaI) ** np.log2(i + 1)) / self.numClasses
+            score_ += prob * hierarchical_prior * tf.multiply(tf.matmul(W_, X_), tf.tanh(tf.matmul(V_, X_)))
+        self.score = score_
+        # Classification.
+        if (self.isRegression == False):
+            if self.numClasses > 2:
+                self.prediction = tf.argmax(tf.transpose(self.score), 1)
+            else:
+                self.prediction = tf.argmax(
+                    tf.concat([tf.transpose(self.score),
+                               0 * tf.transpose(self.score)], 1), 1)
+        # Regression.
+        elif (self.isRegression == True):
+            # For regression, the scores are the actual predictions; just return them.
+            self.prediction = self.score
+        return self.prediction
+
diff --git a/tf/edgeml_tf/trainer/bonsaiTrainer.py b/tf/edgeml_tf/trainer/bonsaiTrainer.py
index 97dc460d2..7a9f57022 100644
--- a/tf/edgeml_tf/trainer/bonsaiTrainer.py
+++ b/tf/edgeml_tf/trainer/bonsaiTrainer.py
@@ -12,7 +12,7 @@ class BonsaiTrainer:

     def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
-                 learningRate, X, Y, useMCHLoss=False, outFile=None, regLoss='huber'):
+                 learningRate, X, Y, useMCHLoss=False, outFile=None, regLoss='huber', errorReg='mae'):
         '''
         bonsaiObj - Initialised Bonsai Object and Graph
         lW, lT, lV and lZ are regularisers to Bonsai Params
@@ -64,7 +64,7 @@ def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
         self.accuracy -> 'MAE' for Regression.
         self.accuracy -> 'Accuracy' for Classification.
         '''
-        self.accuracy = self.accuracyGraph()
+        self.accuracy = self.accuracyGraph(errorReg=errorReg)
         self.prediction = self.bonsaiObj.getPrediction()

         if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99:
@@ -125,7 +125,7 @@ def trainGraph(self):

         return self.bonsaiObj.TrainStep

-    def accuracyGraph(self):
+    def accuracyGraph(self, errorReg='mae'):
         '''
         Accuracy Graph to evaluate accuracy when needed
         '''
@@ -145,9 +145,14 @@ def accuracyGraph(self):

         elif (self.bonsaiObj.isRegression is True):
             # Accuracy for regression , in terms of mean absolute error.
-            self.accuracy = utils.mean_absolute_error(tf.reshape(
-                self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
+            if (errorReg == 'mse'):
+                self.accuracy = utils.mean_squared_error(tf.reshape(
+                    self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
+            else:
+                self.accuracy = utils.mean_absolute_error(tf.reshape(
+                    self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
         return self.accuracy
+

     def hardThrsd(self):
         '''
@@ -558,3 +563,266 @@ def train(self, batchSize, totalEpochs, sess,

         if self.outFile is not sys.stdout:
             self.outFile.close()
+
+    def weighted_stochastic_train(self, batchSize, totalEpochs, sess, Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
+        '''
+        Changes compared to the original train():
+        - Stochastic gradient descent implementation to reduce training time for weighted Bonsai trees.
+        - Mean squared error is used as the regression metric.
+        - The weighted Bonsai node predictor is used.
+        '''
+        resultFile = open(dataDir + '/TFBonsaiResults.txt', 'a+')
+        numIters = Xtrain.shape[0] / batchSize
+        totalBatches = numIters * totalEpochs
+        bonsaiObjSigmaI = 1
+        counter = 0
+
+        if self.bonsaiObj.numClasses > 2:
+            trimlevel = 15
+        else:
+            trimlevel = 5
+        ihtDone = 0
+        if (self.bonsaiObj.isRegression is True):
+            maxTestAcc = 100000007
+        else:
+            maxTestAcc = -10000
+        if self.isDenseTraining is True:
+            ihtDone = 1
+            bonsaiObjSigmaI = 1
+        itersInPhase = 0
+
+        header = '*' * 20
+        for i in range(totalEpochs):
+            print("\nEpoch Number: " + str(i), file=self.outFile)
+
+            '''
+            trainAcc -> For Regression, it is 'Mean Squared Error'.
+            trainAcc -> For Classification, it is 'Accuracy'.
+            '''
+            trainAcc = 0.0
+            trainLoss = 0.0
+
+            numIters = int(numIters)
+            for j in range(numIters):
+
+                if counter == 0:
+                    msg = " Dense Training Phase Started "
+                    print("\n%s%s%s\n" %
+                          (header, msg, header), file=self.outFile)
+
+                # Updating the indicator sigma
+                if ((counter == 0) or (counter == int(totalBatches / 3.0)) or
+                        (counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False):
+                    bonsaiObjSigmaI = 1
+                    itersInPhase = 0
+
+                elif (itersInPhase % 100 == 0):
+                    indices = np.random.choice(Xtrain.shape[0], 100)
+                    batchX = Xtrain[indices, :]
+                    batchY = Ytrain[indices, :]
+                    batchY = np.reshape(
+                        batchY, [-1, self.bonsaiObj.numClasses])
+
+                    _feed_dict = {self.X: batchX}
+                    Xcapeval = self.X_.eval(feed_dict=_feed_dict)
+                    Teval = self.bonsaiObj.T.eval()
+
+                    sum_tr = 0.0
+                    for k in range(0, self.bonsaiObj.internalNodes):
+                        sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))
+
+                    if (self.bonsaiObj.internalNodes > 0):
+                        sum_tr /= (100 * self.bonsaiObj.internalNodes)
+                        sum_tr = 0.1 / sum_tr
+                    else:
+                        sum_tr = 0.1
+                    sum_tr = min(
+                        1000, sum_tr * (2**(float(itersInPhase) /
+                                            (float(totalBatches) / 30.0))))
+
+                    bonsaiObjSigmaI = sum_tr
+
+                itersInPhase += 1
+                batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
+                batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
+                batchY = np.reshape(
+                    batchY, [-1, self.bonsaiObj.numClasses])
+
+                if self.bonsaiObj.numClasses > 2:
+                    if self.useMCHLoss is True:
+                        _feed_dict = {self.X: batchX, self.Y: batchY,
+                                      self.batch_th: batchY.shape[0],
+                                      self.sigmaI: bonsaiObjSigmaI}
+                    else:
+                        _feed_dict = {self.X: batchX, self.Y: batchY,
+                                      self.sigmaI: bonsaiObjSigmaI}
+                else:
+                    _feed_dict = {self.X: batchX, self.Y: batchY,
+                                  self.sigmaI: bonsaiObjSigmaI}
+
+                # Stochastic training: fetch the Bonsai training op (as returned by
+                # trainGraph) together with the loss and accuracy for this batch.
+                _, batchLoss, batchAcc = sess.run(
+                    [self.bonsaiObj.TrainStep, self.loss, self.accuracy],
+                    feed_dict=_feed_dict)
+
+                # Classification.
+                if (self.bonsaiObj.isRegression is False):
+                    trainAcc += batchAcc
+                    trainLoss += batchLoss
+                # Regression.
+                else:
+                    trainAcc += np.mean(batchAcc)
+                    trainLoss += np.mean(batchLoss)
+
+                # Training routine involving IHT and sparse retraining
+                if (counter >= int(totalBatches / 3.0) and
+                        (counter < int(2 * totalBatches / 3.0)) and
+                        counter % trimlevel == 0 and
+                        self.isDenseTraining is False):
+                    self.runHardThrsd(sess)
+                    if ihtDone == 0:
+                        msg = " IHT Phase Started "
+                        print("\n%s%s%s\n" %
+                              (header, msg, header), file=self.outFile)
+                    ihtDone = 1
+                elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
+                       (counter < int(2 * totalBatches / 3.0)) and
+                       counter % trimlevel != 0 and
+                       self.isDenseTraining is False) or
+                        (counter >= int(2 * totalBatches / 3.0) and
+                         self.isDenseTraining is False)):
+                    self.runSparseTraining(sess)
+                    if counter == int(2 * totalBatches / 3.0):
+                        msg = " Sparse Retraining Phase Started "
+                        print("\n%s%s%s\n" %
+                              (header, msg, header), file=self.outFile)
+                counter += 1
+            try:
+                if (self.bonsaiObj.isRegression is True):
+                    print("\nRegression Train Loss: " + str(trainLoss / numIters) +
+                          "\nTraining MSE (Regression): " +
+                          str(trainAcc / numIters),
+                          file=self.outFile)
+                else:
+                    print("\nClassification Train Loss: " + str(trainLoss / numIters) +
+                          "\nTraining accuracy (Classification): " +
+                          str(trainAcc / numIters),
+                          file=self.outFile)
+            except Exception:
+                continue
+
+            oldSigmaI = bonsaiObjSigmaI
+            bonsaiObjSigmaI = 1e9
+
+            if self.bonsaiObj.numClasses > 2:
+                if self.useMCHLoss is True:
+                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                                  self.batch_th: Ytest.shape[0],
+                                  self.sigmaI: bonsaiObjSigmaI}
+                else:
+                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                                  self.sigmaI: bonsaiObjSigmaI}
+            else:
+                _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                              self.sigmaI: bonsaiObjSigmaI}
+
+            # Evaluate on the test set directly instead of extracting the model.
+            testAcc, testLoss, regTestLoss, pred = sess.run(
+                [self.accuracy, self.loss, self.regLoss, self.prediction], feed_dict=_feed_dict)
+
+            if ihtDone == 0:
+                if (self.bonsaiObj.isRegression is False):
+                    maxTestAcc = -10000
+                    maxTestAccEpoch = i
+                elif (self.bonsaiObj.isRegression is True):
+                    maxTestAcc = testAcc
+                    maxTestAccEpoch = i
+
+            else:
+                if (self.bonsaiObj.isRegression is False):
+                    if maxTestAcc <= testAcc:
+                        maxTestAccEpoch = i
+                        maxTestAcc = testAcc
+                        self.saveParams(currDir)
+                        self.saveParamsForSeeDot(currDir)
+                elif (self.bonsaiObj.isRegression is True):
+                    print("Minimum Test MSE so far: ", np.mean(maxTestAcc))
+                    if maxTestAcc >= testAcc:
+                        # For regression, we're more interested in the minimum
+                        # MSE.
+                        maxTestAccEpoch = i
+                        maxTestAcc = testAcc
+                        self.saveParams(currDir)
+                        self.saveParamsForSeeDot(currDir)
+
+            if (self.bonsaiObj.isRegression is True):
+                print("Testing MSE %g" % np.mean(testAcc), file=self.outFile)
+            else:
+                print("Test accuracy %g" % np.mean(testAcc), file=self.outFile)
+
+            if (self.bonsaiObj.isRegression is True):
+                testAcc = np.mean(testAcc)
+            else:
+                testAcc = testAcc
+                maxTestAcc = maxTestAcc
+
+            print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) +
+                  " + " + str(regTestLoss) + " = " + str(testLoss) + "\n",
+                  file=self.outFile)
+            self.outFile.flush()
+
+            bonsaiObjSigmaI = oldSigmaI
+
+        # sigmaI has to be set to infinity to ensure
+        # only a single path is used in inference
+        bonsaiObjSigmaI = 1e9
+        print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
+              str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
+              str(self.getModelSize()[2]) + "\n", file=self.outFile)
+
+        if (self.bonsaiObj.isRegression is True):
+            maxTestAcc = np.mean(maxTestAcc)
+
+        if (self.bonsaiObj.isRegression is True):
+            print("For Regression, Minimum MSE at compressed" +
+                  " model size (including early stopping): " +
+                  str(maxTestAcc) + " at Epoch: " +
+                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
+                  " MSE: " + str(testAcc), file=self.outFile)
+
+            resultFile.write("MinTestMSE: " + str(maxTestAcc) +
+                             " at Epoch(totalEpochs): " +
+                             str(maxTestAccEpoch + 1) +
+                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
+                             str(float(self.getModelSize()[1]) / 1024.0) +
+                             " KB hasSparse: " + str(self.getModelSize()[2]) +
+                             " Param Directory: " +
+                             str(os.path.abspath(currDir)) + "\n")
+
+        elif (self.bonsaiObj.isRegression is False):
+            print("For Classification, Maximum Test accuracy at compressed" +
+                  " model size (including early stopping): " +
+                  str(maxTestAcc) + " at Epoch: " +
+                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
+                  " Accuracy: " + str(testAcc), file=self.outFile)
+
+            resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
+                             " at Epoch(totalEpochs): " +
+                             str(maxTestAccEpoch + 1) +
+                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
+                             str(float(self.getModelSize()[1]) / 1024.0) +
+                             " KB hasSparse: " + str(self.getModelSize()[2]) +
+                             " Param Directory: " +
+                             str(os.path.abspath(currDir)) + "\n")
+        print("The Model Directory: " + currDir + "\n")
+
+        resultFile.close()
+        self.outFile.flush()
+
+        if self.outFile is not sys.stdout:
+            self.outFile.close()
+
diff --git a/tf/edgeml_tf/utils.py b/tf/edgeml_tf/utils.py
index 64a4a0c4d..a9eb13adb 100644
--- a/tf/edgeml_tf/utils.py
+++ b/tf/edgeml_tf/utils.py
@@ -96,6 +96,11 @@ def mean_absolute_error(logits, label):
     '''
     return tf.reduce_mean(tf.abs(tf.subtract(logits, label)))
 
+def mean_squared_error(logits, label):
+    '''
+    Function to compute the mean squared error.
+    '''
+    return tf.reduce_mean(tf.square(tf.subtract(logits, label)))
 
 def hardThreshold(A, s):
     '''
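
Reviewer note (not part of the patch): the sketch below is a NumPy-only illustration of
the path-smoothed score that weighted_call computes, useful for sanity-checking the
hierarchical prior by hand. All names, shapes, and values are invented for this example
and are not part of the EdgeML API; the tree is assumed balanced, and sigma stands in
for sigmaI.

import numpy as np

rng = np.random.default_rng(0)
n_classes, proj_dim, depth = 3, 4, 2
total_nodes = 2 ** (depth + 1) - 1        # 7 nodes in a balanced depth-2 tree
internal_nodes = 2 ** depth - 1           # 3 internal (branching) nodes

Zx = rng.normal(size=(proj_dim, 1))       # projected input Z x
W = rng.normal(size=(total_nodes, n_classes, proj_dim))
V = rng.normal(size=(total_nodes, n_classes, proj_dim))
T = rng.normal(size=(internal_nodes, proj_dim))
sigma = 0.1                               # small positive smoothing constant

node_prob = [1.0]                         # the root is always reached
score = np.zeros((n_classes, 1))
for k in range(total_nodes):
    if k > 0:
        # Probability of reaching node k, given the parent's reach probability.
        parent = int(np.ceil(k / 2.0) - 1)
        p = (1 + ((-1) ** (k + 1)) * np.tanh(sigma * (T[parent] @ Zx))) / 2.0
        node_prob.append(float(node_prob[parent] * p))
    # The patch uses exponent 1 for the root and log2(k + 1) for deeper nodes.
    exponent = 1.0 if k == 0 else np.log2(k + 1)
    prior = (n_classes * (1 + sigma) ** exponent) / n_classes
    score += node_prob[k] * prior * (W[k] @ Zx) * np.tanh(V[k] @ Zx)

print("path-smoothed score per class:", score.ravel())

Because the prior grows as (1 + sigma) ** depth, deeper nodes contribute more to the
score, which is the smoothing behaviour described in the weighted_call docstring.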