28 changes: 22 additions & 6 deletions Project2-Character-Recognition/README.md
@@ -3,12 +3,28 @@ CUDA Character Recognition

**University of Pennsylvania, CIS 565: GPU Programming and Architecture, Project 2**

* (TODO) YOUR NAME HERE
* (TODO) [LinkedIn](), [personal website](), [twitter](), etc.
* Tested on: (TODO) Windows 22, i7-2222 @ 2.22GHz 22GB, GTX 222 222MB (Moore 2222 Lab)
* Tabatha Hickman
* [LinkedIn](https://www.linkedin.com/in/tabatha-hickman-335987140/)
* Tested on: Windows 10 Pro, i7-5600U CPU @ 2.60GHz 16GB, GeForce 840M (personal computer)

### (TODO: Your README)
## Neural Network Implementation

Include analysis, etc. (Remember, this is public, so don't put
anything here that you don't want to share with the world.)
The purpose of this project was to build a neural network that does its computations on the GPU. I created a multi-layer perceptron with one hidden layer, so there are 3 layers in total (input, hidden, output). The network is evaluated by feeding values forward from one layer to the next. To compute each new layer, every node in that layer takes the sum of all nodes in the previous layer, each multiplied by the weight connecting the two nodes, and passes that sum through an activation function. Here the activation function is the sigmoid ```f(x) = 1/(1+e^-x)```.
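
As a rough illustration of the feed-forward step described above, here is a minimal CPU sketch of one layer. The names `sigmoid` and `feedForwardLayer` are hypothetical and not part of the project; the weight layout `weights[in * outDim + out]` is assumed to match the indexing used by `kernSumWeights`.

```cpp
#include <cmath>
#include <vector>

// Logistic activation f(x) = 1 / (1 + e^-x).
inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Hypothetical CPU reference for one layer of the network.
// Assumes weights are stored as weights[in * outDim + out].
std::vector<float> feedForwardLayer(const std::vector<float>& input,
                                    const std::vector<float>& weights,
                                    int outDim) {
    const int inDim = static_cast<int>(input.size());
    std::vector<float> output(outDim);
    for (int out = 0; out < outDim; ++out) {
        // Weighted sum over every node in the previous layer.
        float sum = 0.0f;
        for (int in = 0; in < inDim; ++in) {
            sum += weights[in * outDim + out] * input[in];
        }
        // The activation turns the sum into the node's value.
        output[out] = sigmoid(sum);
    }
    return output;
}
```

Calling it once for input to hidden and once more for hidden to output gives a CPU result that the GPU output can be compared against.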

We want to find the set of weights that makes the outputs of the network as accurate as possible. This is done in a training phase. The weights start as random values. Then, given inputs and their corresponding target outputs, we run the inputs through the network, compare the outputs with the targets, and compute the error. Backward propagation then updates each weight in proportion to its contribution to that error, so that the output is more accurate the next time.

Once the network has been trained adequately, we can run new inputs through it and check how closely the outputs match what we expect.

Using the provided weights for a working XOR neural network, I verified that my code builds the network and feeds forward correctly. I was also able to train my own fairly accurate weights for XOR, using a target error of 0.01:

```
Ran 13101 iterations of training
(0, 0) expected: 0.000000, result 0.071486
(0, 1) expected: 1.000000, result 0.930205
(1, 0) expected: 1.000000, result 0.923021
(1, 1) expected: 0.000000, result 0.063928
```
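
As a sanity check on these numbers, the error can be recomputed on the CPU. The sketch below assumes the error is half the sum of squared differences between expected and actual outputs, as in `mlpTrain`; the helper name `halfSquaredError` is hypothetical, and treating the four XOR cases as one vector is an assumption about how the target error was applied.

```cpp
#include <cstddef>
#include <vector>

// Half the sum of squared differences between expected and actual values.
float halfSquaredError(const std::vector<float>& expected,
                       const std::vector<float>& result) {
    float sum = 0.0f;
    for (std::size_t n = 0; n < expected.size(); ++n) {
        const float diff = expected[n] - result[n];
        sum += diff * diff;
    }
    return sum / 2.0f;
}
```

Passing the expected values `{0, 1, 1, 0}` and the four results listed above gives a total just under 0.01, which is consistent with the stated target error if the four per-input errors are summed.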

Unfortunately, I had a lot of trouble extending this to character recognition. Training does not seem to work: the error is huge and does not improve with further iterations. When I tried to debug this, I started getting "CUDA grid launch failed" errors. From what I could find, this is related to the TDR (Timeout Detection and Recovery) setting used while debugging, but I couldn't find where to change that setting.


@@ -7,5 +7,5 @@ set(SOURCE_FILES

cuda_add_library(character_recognition
${SOURCE_FILES}
OPTIONS -arch=sm_20
OPTIONS -arch=sm_30
)
250 changes: 250 additions & 0 deletions Project2-Character-Recognition/character_recognition/mlp.cu
@@ -1,5 +1,6 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/random.h>
#include "common.h"
#include "mlp.h"

@@ -23,5 +24,254 @@ namespace CharacterRecognition {
}
*/

#define blockSize 128

__host__ __device__ unsigned int hash(unsigned int a) {
a = (a + 0x7ed55d16) + (a << 12);
a = (a ^ 0xc761c23c) ^ (a >> 19);
a = (a + 0x165667b1) + (a << 5);
a = (a + 0xd3a2646c) ^ (a << 9);
a = (a + 0xfd7046c5) + (a << 3);
a = (a ^ 0xb55a4f09) ^ (a >> 16);
return a;
}

__host__ __device__ float genRandom(float time, int index) {
thrust::default_random_engine rng(hash((int)(index * time)));
thrust::uniform_real_distribution<float> unitDistrib(-1, 1);

return (float)unitDistrib(rng);
}

__global__ void kernInitRandomWeights(int N, float* wtMat, float scale)
{
int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index < N) {
float rand = genRandom(N, index);
wtMat[index] = scale * rand;
}
}

__global__ void kernInitZero(int N, float* data)
{
int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index < N)
{
data[index] = 0;
}
}

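__global__ void kernSumWeights(int iDim, int oDim, float* wtMat, float* idata, float* odata)
{
int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index >= oDim) { return; }

// accumulate in a register so the result does not depend on the previous contents of odata
float sum = 0.0f;
for (int idx = 0; idx < iDim; idx++)
{
// weights are stored as wtMat[inputNode * oDim + outputNode]
int wtIdx = idx * oDim + index;
sum += wtMat[wtIdx] * idata[idx];
}
odata[index] = sum;
}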

__global__ void kernActivationFxn(int N, float* idata, float* odata)
{
int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index >= N) { return; }

float x = idata[index];
float e = exp(-x);
odata[index] = 1.0f / (1.0f + e);
}

__global__ void kernCalcErrors(int N, float* target, float* output, float* odata)
{
int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index >= N) { return; }

odata[index] = target[index] - output[index];
}

__global__ void kernEditWeightsji(int N, int iDim, float lambda, float* hidden, float* errors, float* outputSums,
float* partialErr, float* wtMat)
{
// for hidden to output weights:
// delta = lambda * value of hidden node * (target - output) * derivative of f(x),
// where x is the weighted sum before it went through f (outputSums), so f(x) is recomputed below
// derivative of f = f * (1-f)

int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index >= N) { return; }

int i = index % iDim;
int j = index / iDim;

float x = outputSums[i];
float fx = 1.0f / (1.0f + exp(-x));
partialErr[i] = errors[i] * fx * (1 - fx);
float deltaW = lambda * hidden[j] * partialErr[i];

wtMat[index] += deltaW;
}

__global__ void kernEditWeightskj(int N, int jDim, int iDim, float lambda, float* input, float* hiddenSums,
float* partialErr, float* wji,
float* wtMat)
{
// for input to hidden weights:
// delta = lambda * value of input node * derivative of f(x) * sum over outputs of (partial error * hidden-to-output weight)
// derivative of f = f * (1-f)

int index = (blockIdx.x * blockDim.x) + threadIdx.x;
if (index >= N) { return; }

int j = index % jDim;
int k = index / jDim;

float sumPropErrs = 0;
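for (int i = 0; i < iDim; i++)
{
// wji is stored as wji[hiddenNode * iDim + outputNode], matching kernSumWeights and kernEditWeightsji
sumPropErrs += partialErr[i] * wji[j * iDim + i];
}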

float x = hiddenSums[j];
float fx = 1.0f / (1.0f + exp(-x));
float deltaW = lambda * input[k] * sumPropErrs * fx * (1 - fx);

wtMat[index] += deltaW;
}

void makeWeightMat(int n, float* data)
{
float* dev_data;
cudaMalloc((void**)&dev_data, n * sizeof(float));

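// one thread per weight, so only enough blocks to cover n threads are needed
int numBlocks = (n + blockSize - 1) / blockSize;
kernInitRandomWeights << <numBlocks, blockSize >> > (n, dev_data, 30);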

cudaMemcpy(data, dev_data, n * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(dev_data);
}

// TODO: implement required elements for MLP sections 1 and 2 here
float mlpTrain(int i, int j, int k, float* odata, float* idata, float* wkj, float* wji, float* target)
{
float *dev_input, *dev_hidden, *dev_output;
float *dev_hiddenSums, *dev_outputSums;
float *dev_wkj, *dev_wji;
float *dev_target, *dev_errors, *dev_partialErr, *dev_tempwji;

cudaMalloc((void**)&dev_input, k * sizeof(float));
cudaMalloc((void**)&dev_hidden, j * sizeof(float));
cudaMalloc((void**)&dev_output, i * sizeof(float));
cudaMemcpy(dev_input, idata, k * sizeof(float), cudaMemcpyHostToDevice);

cudaMalloc((void**)&dev_hiddenSums, j * sizeof(float));
cudaMalloc((void**)&dev_outputSums, i * sizeof(float));

cudaMalloc((void**)&dev_wkj, k * j * sizeof(float));
cudaMalloc((void**)&dev_wji, j * i * sizeof(float));
cudaMemcpy(dev_wkj, wkj, k * j * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(dev_wji, wji, j * i * sizeof(float), cudaMemcpyHostToDevice);

cudaMalloc((void**)&dev_target, i * sizeof(float));
cudaMalloc((void**)&dev_errors, i * sizeof(float));
cudaMalloc((void**)&dev_partialErr, i * sizeof(float));
cudaMalloc((void**)&dev_tempwji, i * j * sizeof(float));
cudaMemcpy(dev_target, target, i * sizeof(float), cudaMemcpyHostToDevice);

// initialize the pre-activation sum buffers to zeros
kernInitZero << <j, blockSize >> > (j, dev_hiddenSums);
kernInitZero << <i, blockSize >> > (i, dev_outputSums);

// input -> hidden
kernSumWeights << <j, blockSize >> > (k, j, dev_wkj, dev_input, dev_hiddenSums);
kernActivationFxn << <j, blockSize >> > (j, dev_hiddenSums, dev_hidden);

// hidden -> output
kernSumWeights << <i, blockSize >> > (j, i, dev_wji, dev_hidden, dev_outputSums);
kernActivationFxn << <i, blockSize >> > (i, dev_outputSums, dev_output);

// calculate error, lambda
kernCalcErrors << <i, blockSize >> > (i, dev_target, dev_output, dev_errors);

float* errs = new float[i];
cudaMemcpy(errs, dev_errors, i * sizeof(float), cudaMemcpyDeviceToHost);
float sumErr = 0;
for (int e = 0; e < i; e++)
{
sumErr += (errs[e]*errs[e]);
}
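sumErr /= 2.0f;
delete[] errs; // host copy of the errors is no longer needed
// the learning rate is set to the current error
float lambda = sumErr;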

// update weights (the copy of wji lets the input -> hidden update use the pre-update values)
cudaMemcpy(dev_tempwji, dev_wji, j * i * sizeof(float), cudaMemcpyDeviceToDevice);
// both kernels apply f to their sums argument, so pass the pre-activation sums
kernEditWeightsji << <j*i, blockSize >> > (j*i, i, lambda, dev_hidden, dev_errors, dev_outputSums,
dev_partialErr, dev_wji);
kernEditWeightskj << <k*j, blockSize >> > (k*j, j, i, lambda, dev_input, dev_hiddenSums, dev_partialErr,
dev_tempwji, dev_wkj);

cudaMemcpy(odata, dev_output, i * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(wkj, dev_wkj, k * j * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(wji, dev_wji, j * i * sizeof(float), cudaMemcpyDeviceToHost);

cudaFree(dev_input);
cudaFree(dev_hidden);
cudaFree(dev_output);

cudaFree(dev_hiddenSums);
cudaFree(dev_outputSums);

cudaFree(dev_wkj);
cudaFree(dev_wji);

cudaFree(dev_target);
cudaFree(dev_errors);
cudaFree(dev_partialErr);
cudaFree(dev_tempwji);

return sumErr;
}

void mlpRun(int i, int j, int k, float* odata, float* idata, float* wkj, float* wji)
{
float *dev_input, *dev_hidden, *dev_output;
float *dev_hiddenSums, *dev_outputSums;
float *dev_wkj, *dev_wji;

cudaMalloc((void**)&dev_input, k * sizeof(float));
cudaMalloc((void**)&dev_hidden, j * sizeof(float));
cudaMalloc((void**)&dev_output, i * sizeof(float));
cudaMemcpy(dev_input, idata, k * sizeof(float), cudaMemcpyHostToDevice);

cudaMalloc((void**)&dev_hiddenSums, j * sizeof(float));
cudaMalloc((void**)&dev_outputSums, i * sizeof(float));

cudaMalloc((void**)&dev_wkj, k * j * sizeof(float));
cudaMalloc((void**)&dev_wji, j * i * sizeof(float));
cudaMemcpy(dev_wkj, wkj, k * j * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(dev_wji, wji, j * i * sizeof(float), cudaMemcpyHostToDevice);

// initialize the pre-activation sum buffers to zeros
kernInitZero << <j, blockSize >> > (j, dev_hiddenSums);
kernInitZero << <i, blockSize >> > (i, dev_outputSums);

// input -> hidden
kernSumWeights << <j, blockSize >> > (k, j, dev_wkj, dev_input, dev_hiddenSums);
kernActivationFxn << <j, blockSize >> > (j, dev_hiddenSums, dev_hidden);

// hidden -> output
kernSumWeights << <i, blockSize >> > (j, i, dev_wji, dev_hidden, dev_outputSums);
kernActivationFxn << <i, blockSize >> > (i, dev_outputSums, dev_output);

cudaMemcpy(odata, dev_output, i * sizeof(float), cudaMemcpyDeviceToHost);

cudaFree(dev_input);
cudaFree(dev_hidden);
cudaFree(dev_output);

cudaFree(dev_hiddenSums);
cudaFree(dev_outputSums);

cudaFree(dev_wkj);
cudaFree(dev_wji);
}
}
5 changes: 5 additions & 0 deletions Project2-Character-Recognition/character_recognition/mlp.h
@@ -5,5 +5,10 @@
namespace CharacterRecognition {
Common::PerformanceTimer& timer();

void makeWeightMat(int n, float* data);

// TODO: implement required elements for MLP sections 1 and 2 here
float mlpTrain(int i, int j, int k, float* odata, float* idata, float* wkj, float* wji, float* target);

void mlpRun(int i, int j, int k, float* odata, float* idata, float* wkj, float* wji);
}