diff --git a/Dockerfile b/Dockerfile
index 888a8da..0454cf0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,6 +11,10 @@ RUN pip3 install --upgrade pip
ADD requirements.txt requirements.txt
RUN pip3 install -r requirements.txt
+# In this example, we define the Hugging Face model name as an ENV variable
+# and pass it from here to download.py and app.py
+ENV HF_MODEL_NAME=bert-base-uncased
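+# If you prefer to pick the model at build time instead, one option (a sketch,
+# not part of this template) is a build argument, e.g.
+#   ARG HF_MODEL_NAME=bert-base-uncased
+#   ENV HF_MODEL_NAME=${HF_MODEL_NAME}
+# and then build with `docker build --build-arg HF_MODEL_NAME=<model> .`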
+
# We add the banana boilerplate here
ADD server.py .
diff --git a/README.md b/README.md
index 89068bd..4ff47d6 100644
--- a/README.md
+++ b/README.md
@@ -17,4 +17,37 @@ Generalize this framework to [deploy anything on Banana](https://docs.banana.dev
+# Local testing
+
+## With Docker
+
+To test this serverless template locally with Docker, build the Docker image and then run it.
+In the root of this directory, run:
+```
+docker build . -t serverless-template
+```
+You can then run the container. Here we also publish port 8000, so the server is reachable on localhost outside
+the container, and enable GPU acceleration.
+```
+docker run -p 8000:8000 --gpus=all serverless-template
+```
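+
+If you are on a machine without a GPU (or without the NVIDIA Container Toolkit required for `--gpus`), you can
+drop the flag; app.py falls back to the CPU when CUDA is not available:
+```
+docker run -p 8000:8000 serverless-template
+```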
+
+## Without Docker
+
+Testing your code without Docker is straightforward. Remember to pass in the Hugging Face model name as
+an ENV variable, in this case:
+```
+export HF_MODEL_NAME=bert-base-uncased
+```
+Make sure you have the required dependencies:
+```
+pip3 install -r requirements.txt
+```
+And then simply run server.py:
+```
+python3 server.py
+```
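+
+In either case the server listens on port 8000. As a quick sanity check you can send it a request; this is a
+minimal sketch that assumes the boilerplate server.py accepts a POST to `/` with a JSON body containing a
+`prompt` field (adjust it to match your server.py if it differs):
+```
+curl -X POST http://localhost:8000/ \
+  -H "Content-Type: application/json" \
+  -d '{"prompt": "Paris is the [MASK] of France."}'
+```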
+
## Use Banana for scale.
diff --git a/app.py b/app.py
index 7f6b061..2137df3 100644
--- a/app.py
+++ b/app.py
@@ -1,13 +1,17 @@
from transformers import pipeline
import torch
+import os
# Init is ran on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
global model
+
+ # In this example, we get the model name as an ENV variable defined in the Dockerfile
+ hf_model_name = os.getenv("HF_MODEL_NAME")
device = 0 if torch.cuda.is_available() else -1
- model = pipeline('fill-mask', model='bert-base-uncased', device=device)
+ model = pipeline('fill-mask', model=hf_model_name, device=device)
# Inference is ran for every server call
# Reference your preloaded global model variable here.
diff --git a/download.py b/download.py
index 9f2956d..b9fc706 100644
--- a/download.py
+++ b/download.py
@@ -4,10 +4,15 @@
# In this example: A Huggingface BERT model
from transformers import pipeline
+import os
def download_model():
+
+ # In this example, we get the model name as an ENV variable defined in the Dockerfile
+ hf_model_name = os.getenv("HF_MODEL_NAME")
+
# do a dry run of loading the huggingface model, which will download weights
- pipeline('fill-mask', model='bert-base-uncased')
+ pipeline('fill-mask', model=hf_model_name)
if __name__ == "__main__":
download_model()
\ No newline at end of file