Added hierarchical prior. #247

Open
wants to merge 1 commit into master
59 changes: 59 additions & 0 deletions tf/edgeml_tf/tflite/bonsaiLayer.py
@@ -72,6 +72,65 @@ def call(self, X):
self.prediction = self.score

return self.prediction

def weighted_call(self, X):
'''
The original Bonsai model learns a single, shallow, sparse tree whose
prediction for a point x is

    y(x) = Σ_k I_k(x) (W_k^T Zx) ◦ tanh(σ V_k^T Zx)

This function adds path smoothing, an accuracy improvement for deeper
trees: the prediction becomes a weighted average over every node on the
path from the root to the leaf, so each node contributes to the score.

A node's contribution grows with its depth, and since the Bonsai tree is
balanced, the weight along each path depends only on the number of
classes and the depth. Even when a leaf sees few observations, the
internal (non-leaf) nodes correct the target distribution by smoothing
it out; think of it as a hierarchical prior assignment. This lets us
increase depth while avoiding overfitted trees.

    y(x) = (Σ_k I_k(x) hierarchical_prior(k) (W_k^T Zx) ◦ tanh(V_k^T Zx)) / k

Which node to visit is still decided by the indicator I_k(x); the node
weight is

    hierarchical_prior(k) = (n_classes * (1 + σ)**depth(k)) / n_classes

where σ (sigma) is a small positive number ([0.01, 0.1] are good values).

Future work: tree pruning can be added to the model to reduce the number
of nodes; use CUDA or Dask for distributed training.
'''
sigmaI = self.sigma
errmsg = "Dimension Mismatch, X is [_, self.dataDimension]"
assert (len(X.shape) == 2 and int(X.shape[1]) == self.dataDimension), errmsg
X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True), self.projectionDimension)
W_ = self.W[0:(self.numClasses)]
V_ = self.V[0:(self.numClasses)]
__nodeProb = []
__nodeProb.append(1)
# Node count starts from 1, so the root gets prior (1 + sigmaI)**1
hierarchical_prior = (self.numClasses * (1 + sigmaI)) / self.numClasses
score_ = __nodeProb[0] * hierarchical_prior * tf.multiply(tf.matmul(W_, X_), tf.tanh(tf.matmul(V_, X_)))
for i in range(1, self.totalNodes):
W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]
T_ = tf.reshape(self.T[int(np.ceil(i / 2.0) - 1.0)], [-1, self.projectionDimension])
prob = (1 + ((-1) ** (i + 1)) * tf.tanh(tf.multiply(sigmaI, tf.matmul(T_, X_))))  # Soft indicator: routing probability for node i
prob = tf.divide(prob, 2.0)
prob = __nodeProb[int(np.ceil(i / 2.0) - 1.0)] * prob
__nodeProb.append(prob)
hierarchical_prior = (self.numClasses * (1 + sigmaI) ** np.log2(i + 1)) / self.numClasses  # Depth-weighted prior: depth(i) = log2(i + 1)
score_ += prob * hierarchical_prior * tf.multiply(tf.matmul(W_, X_), tf.tanh(tf.matmul(V_, X_)))
self.score = score_
# Classification.
if (self.isRegression is False):
if self.numClasses > 2:
self.prediction = tf.argmax(tf.transpose(self.score), 1)
else:
self.prediction = tf.argmax(
tf.concat([tf.transpose(self.score),
0 * tf.transpose(self.score)], 1), 1)
# Regression.
elif (self.isRegression is True):
# For regression, the scores are the actual predictions; just return them.
self.prediction = self.score
return self.prediction
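
To make the prior concrete, here is a small standalone sketch (plain NumPy; the sigma, class-count, and tree-size values are illustrative, not taken from the PR) that evaluates hierarchical_prior(k) for every node of a balanced depth-3 tree, mirroring the formula in weighted_call:

import numpy as np

sigma = 0.1        # smoothing strength; the docstring suggests [0.01, 0.1]
num_classes = 4    # illustrative; note it cancels out of the formula
total_nodes = 15   # balanced tree of depth 3: 2**4 - 1 nodes

# Node 0 is the root; the layer treats it as depth 1.
priors = [(num_classes * (1 + sigma)) / num_classes]
for i in range(1, total_nodes):
    depth = np.log2(i + 1)  # depth(i) = log2(i + 1), as in weighted_call
    priors.append((num_classes * (1 + sigma) ** depth) / num_classes)

for i, p in enumerate(priors):
    print("node %2d  prior %.4f" % (i, p))
# Deeper nodes receive larger weights, so leaves dominate the average
# while shallow nodes still smooth the target distribution.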




278 changes: 273 additions & 5 deletions tf/edgeml_tf/trainer/bonsaiTrainer.py
@@ -12,7 +12,7 @@
class BonsaiTrainer:

def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
learningRate, X, Y, useMCHLoss=False, outFile=None, regLoss='huber'):
learningRate, X, Y, useMCHLoss=False, outFile=None, regLoss='huber', errorReg='mae'):
'''
bonsaiObj - Initialised Bonsai Object and Graph
lW, lT, lV and lZ are regularisers to Bonsai Params
@@ -64,7 +64,7 @@ def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
self.accuracy -> 'MAE' for Regression.
self.accuracy -> 'Accuracy' for Classification.
'''
self.accuracy = self.accuracyGraph()
self.accuracy = self.accuracyGraph(errorReg=errorReg)
self.prediction = self.bonsaiObj.getPrediction()

if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99:
@@ -125,7 +125,7 @@ def trainGraph(self):

return self.bonsaiObj.TrainStep

def accuracyGraph(self):
def accuracyGraph(self, errorReg='mae'):
'''
Accuracy Graph to evaluate accuracy when needed
'''
@@ -145,9 +145,14 @@ def accuracyGraph(self):

elif (self.bonsaiObj.isRegression is True):
# Accuracy for regression , in terms of mean absolute error.
self.accuracy = utils.mean_absolute_error(tf.reshape(
self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
if(errorReg == 'mse'):
self.accuracy = utils.mean_squared_error(tf.reshape(
self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
else:
self.accuracy = utils.mean_absolute_error(tf.reshape(
self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1]))
return self.accuracy
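
With the new errorReg switch, the regression metric is chosen when the trainer is constructed. A minimal sketch, assuming hypothetical regulariser, sparsity, and learning-rate values (only the errorReg argument is new in this PR):

# All hyperparameter values below are illustrative placeholders.
trainer = BonsaiTrainer(bonsaiObj,
                        lW=1e-3, lT=1e-3, lV=1e-3, lZ=1e-5,
                        sW=0.3, sT=0.62, sV=0.26, sZ=0.2,
                        learningRate=0.01, X=X, Y=Y,
                        useMCHLoss=False, outFile=None,
                        regLoss='huber', errorReg='mse')  # report MSE instead of the default MAE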


def hardThrsd(self):
'''
Expand Down Expand Up @@ -558,3 +563,266 @@ def train(self, batchSize, totalEpochs, sess,

if self.outFile is not sys.stdout:
self.outFile.close()

def weighted_stochastic_train(self, batchSize, totalEpochs, sess, Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
'''
Changes compared to the original train():
- Stochastic gradient descent implementation to reduce training time for weighted Bonsai trees.
- Mean squared error is used for regression training.
- The weighted Bonsai node predictor is used.
'''
resultFile = open(dataDir + '/TFBonsaiResults.txt', 'a+')
numIters = Xtrain.shape[0] / batchSize
totalBatches = numIters * totalEpochs
bonsaiObjSigmaI = 1
counter = 0

if self.bonsaiObj.numClasses > 2:
trimlevel = 15
else:
trimlevel = 5
ihtDone = 0
if (self.bonsaiObj.isRegression is True):
maxTestAcc = 100000007
else:
maxTestAcc = -10000
if self.isDenseTraining is True:
ihtDone = 1
bonsaiObjSigmaI = 1
itersInPhase = 0

header = '*' * 20
for i in range(totalEpochs):
print("\nEpoch Number: " + str(i), file=self.outFile)

'''
trainAcc -> For Regression, it is 'Mean Squared Error'.
trainAcc -> For Classification, it is 'Accuracy'.
'''
trainAcc = 0.0
trainLoss = 0.0

numIters = int(numIters)
for j in range(numIters):

if counter == 0:
msg = " Dense Training Phase Started "
print("\n%s%s%s\n" %
(header, msg, header), file=self.outFile)

# Updating the indicator sigma
if ((counter == 0) or (counter == int(totalBatches / 3.0)) or
(counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False):
bonsaiObjSigmaI = 1
itersInPhase = 0

elif (itersInPhase % 100 == 0):
indices = np.random.choice(Xtrain.shape[0], 100)
batchX = Xtrain[indices, :]
batchY = Ytrain[indices, :]
batchY = np.reshape(
batchY, [-1, self.bonsaiObj.numClasses])

_feed_dict = {self.X: batchX}
Xcapeval = self.X_.eval(feed_dict=_feed_dict)
Teval = self.bonsaiObj.T.eval()

sum_tr = 0.0
for k in range(0, self.bonsaiObj.internalNodes):
sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))

if(self.bonsaiObj.internalNodes > 0):
sum_tr /= (100 * self.bonsaiObj.internalNodes)
sum_tr = 0.1 / sum_tr
else:
sum_tr = 0.1
sum_tr = min(
1000, sum_tr * (2**(float(itersInPhase) /
(float(totalBatches) / 30.0))))

bonsaiObjSigmaI = sum_tr
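# Annealing note: the soft indicator tanh(sigmaI * T x) hardens as sigmaI
# grows. Every 100 iterations sigmaI is rescaled from the observed mean
# |T x| and multiplied by 2**(itersInPhase / (totalBatches / 30)), capped
# at 1000, so routing decisions become increasingly discrete over training.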

itersInPhase += 1
batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
batchY = np.reshape(
batchY, [-1, self.bonsaiObj.numClasses])

if self.bonsaiObj.numClasses > 2:
if self.useMCHLoss is True:
_feed_dict = {self.X: batchX, self.Y: batchY,
self.batch_th: batchY.shape[0],
self.sigmaI: bonsaiObjSigmaI}
else:
_feed_dict = {self.X: batchX, self.Y: batchY,
self.sigmaI: bonsaiObjSigmaI}
else:
_feed_dict = {self.X: batchX, self.Y: batchY,
self.sigmaI: bonsaiObjSigmaI}

# Stochastic training step
_, batchLoss, batchAcc = sess.run(
[self.trainStep, self.loss, self.accuracy],
feed_dict=_feed_dict)

# Classification.
if (self.bonsaiObj.isRegression is False):
trainAcc += batchAcc
trainLoss += batchLoss
# Regression.
else:
trainAcc += np.mean(batchAcc)
trainLoss += np.mean(batchLoss)

# Training routine involving IHT and sparse retraining
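# Phase schedule over totalBatches: dense training in the first third,
# IHT (hard thresholding every `trimlevel` batches) in the middle third,
# and sparse retraining on the surviving support in the final third.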
if (counter >= int(totalBatches / 3.0) and
(counter < int(2 * totalBatches / 3.0)) and
counter % trimlevel == 0 and
self.isDenseTraining is False):
self.runHardThrsd(sess)
if ihtDone == 0:
msg = " IHT Phase Started "
print("\n%s%s%s\n" %
(header, msg, header), file=self.outFile)
ihtDone = 1
elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
(counter < int(2 * totalBatches / 3.0)) and
counter % trimlevel != 0 and
self.isDenseTraining is False) or
(counter >= int(2 * totalBatches / 3.0) and
self.isDenseTraining is False)):
self.runSparseTraining(sess)
if counter == int(2 * totalBatches / 3.0):
msg = " Sparse Retraining Phase Started "
print("\n%s%s%s\n" %
(header, msg, header), file=self.outFile)
counter += 1
try:
if (self.bonsaiObj.isRegression is True):
print("\nRegression Train Loss: " + str(trainLoss / numIters) +
"\nTraining MSE (Regression): " +
str(trainAcc / numIters),
file=self.outFile)
else:
print("\nClassification Train Loss: " + str(trainLoss / numIters) +
"\nTraining accuracy (Classification): " +
str(trainAcc / numIters),
file=self.outFile)
except Exception:
continue

oldSigmaI = bonsaiObjSigmaI
bonsaiObjSigmaI = 1e9

if self.bonsaiObj.numClasses > 2:
if self.useMCHLoss is True:
_feed_dict = {self.X: Xtest, self.Y: Ytest,
self.batch_th: Ytest.shape[0],
self.sigmaI: bonsaiObjSigmaI}
else:
_feed_dict = {self.X: Xtest, self.Y: Ytest,
self.sigmaI: bonsaiObjSigmaI}
else:
_feed_dict = {self.X: Xtest, self.Y: Ytest,
self.sigmaI: bonsaiObjSigmaI}

# This helps in direct testing instead of extracting the model out

testAcc, testLoss, regTestLoss, pred = sess.run(
[self.accuracy, self.loss, self.regLoss, self.prediction], feed_dict=_feed_dict)

if ihtDone == 0:
if (self.bonsaiObj.isRegression is False):
maxTestAcc = -10000
maxTestAccEpoch = i
elif (self.bonsaiObj.isRegression is True):
maxTestAcc = testAcc
maxTestAccEpoch = i

else:
if (self.bonsaiObj.isRegression is False):
if maxTestAcc <= testAcc:
maxTestAccEpoch = i
maxTestAcc = testAcc
self.saveParams(currDir)
self.saveParamsForSeeDot(currDir)
elif (self.bonsaiObj.isRegression is True):
print("Minimum Training MSE : ", np.mean(maxTestAcc))
if maxTestAcc >= testAcc:
# For regression , we're more interested in the minimum
# MSE.
maxTestAccEpoch = i
maxTestAcc = testAcc
self.saveParams(currDir)
self.saveParamsForSeeDot(currDir)

if (self.bonsaiObj.isRegression is True):
print("Testing MSE %g" % np.mean(testAcc), file=self.outFile)
else:
print("Test accuracy %g" % np.mean(testAcc), file=self.outFile)

if (self.bonsaiObj.isRegression is True):
testAcc = np.mean(testAcc)

print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) +
" + " + str(regTestLoss) + " = " + str(testLoss) + "\n",
file=self.outFile)
self.outFile.flush()

bonsaiObjSigmaI = oldSigmaI

# sigmaI has to be set to infinity to ensure
# only a single path is used in inference
bonsaiObjSigmaI = 1e9
print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
str(self.getModelSize()[2]) + "\n", file=self.outFile)

if (self.bonsaiObj.isRegression is True):
maxTestAcc = np.mean(maxTestAcc)

if (self.bonsaiObj.isRegression is True):
print("For Regression, Minimum MSE at compressed" +
" model size(including early stopping): " +
str(maxTestAcc) + " at Epoch: " +
str(maxTestAccEpoch + 1) + "\nFinal Test" +
" MAE: " + str(testAcc), file=self.outFile)

resultFile.write("MinTestMSE: " + str(maxTestAcc) +
" at Epoch(totalEpochs): " +
str(maxTestAccEpoch + 1) +
"(" + str(totalEpochs) + ")" + " ModelSize: " +
str(float(self.getModelSize()[1]) / 1024.0) +
" KB hasSparse: " + str(self.getModelSize()[2]) +
" Param Directory: " +
str(os.path.abspath(currDir)) + "\n")

elif (self.bonsaiObj.isRegression is False):
print("For Classification, Maximum Test accuracy at compressed" +
" model size(including early stopping): " +
str(maxTestAcc) + " at Epoch: " +
str(maxTestAccEpoch + 1) + "\nFinal Test" +
" Accuracy: " + str(testAcc), file=self.outFile)

resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
" at Epoch(totalEpochs): " +
str(maxTestAccEpoch + 1) +
"(" + str(totalEpochs) + ")" + " ModelSize: " +
str(float(self.getModelSize()[1]) / 1024.0) +
" KB hasSparse: " + str(self.getModelSize()[2]) +
" Param Directory: " +
str(os.path.abspath(currDir)) + "\n")
print("The Model Directory: " + currDir + "\n")

resultFile.close()
self.outFile.flush()

if self.outFile is not sys.stdout:
self.outFile.close()
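
A minimal TF1-style driver sketch for the new routine (session handling, data arrays, and directory paths are hypothetical; the argument order follows the method signature above):

# Assumes `trainer` was constructed as in the errorReg example above.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    trainer.weighted_stochastic_train(batchSize=32, totalEpochs=100, sess=sess,
                                      Xtrain=Xtrain, Xtest=Xtest,
                                      Ytrain=Ytrain, Ytest=Ytest,
                                      dataDir='./results', currDir='./model')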




5 changes: 5 additions & 0 deletions tf/edgeml_tf/utils.py
@@ -96,6 +96,11 @@ def mean_absolute_error(logits, label):
'''
return tf.reduce_mean(tf.abs(tf.subtract(logits, label)))

def mean_squared_error(logits, label):
'''
Function to compute the mean squared error.
'''
return tf.reduce_mean(tf.square(tf.subtract(logits, label)))
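
A quick numerical check of the new helper in TF1 graph mode (the tensor values below are made up for illustration):

# ((1 - 0)**2 + (2 - 4)**2) / 2 = 2.5
mse = mean_squared_error(tf.constant([[1.0], [2.0]]),
                         tf.constant([[0.0], [4.0]]))
with tf.Session() as sess:
    print(sess.run(mse))  # 2.5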

def hardThreshold(A, s):
'''