Commit efd41fd

Author: Paolo Di Francesco (committed)
Message: Change dataset and add train, validation, test
1 parent ab51504 commit efd41fd

File tree: 4 files changed, +72 -181 lines

scikit_learn_script_mode_local_training_and_serving/code/scikit_learn_iris.py → scikit_learn_script_mode_local_training_and_serving/code/scikit_learn_california.py (renamed, +14 -2)

@@ -19,6 +19,7 @@
 import joblib
 import pandas as pd
 from sklearn import tree
+from sklearn.metrics import mean_squared_error

 if __name__ == "__main__":
     print("Training Started")
@@ -31,6 +32,7 @@
     parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
     parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
     parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
+    parser.add_argument("--validation", type=str, default=os.environ["SM_CHANNEL_VALIDATION"])

     args = parser.parse_args()
     print("Got Args: {}".format(args))
@@ -57,10 +59,20 @@
     # as your training may require in the ArgumentParser above.
     max_leaf_nodes = args.max_leaf_nodes

-    # Now use scikit-learn's decision tree classifier to train the model.
-    clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
+    # Now use scikit-learn's decision tree regressor to train the model.
+    clf = tree.DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes)
     clf = clf.fit(train_X, train_y)

+    input_files = [os.path.join(args.validation, file) for file in os.listdir(args.validation)]
+    raw_data = [pd.read_csv(file, header=None, engine="python") for file in input_files]
+    validation_data = pd.concat(raw_data)
+    # labels are in the first column
+    validation_y = validation_data.iloc[:, 0]
+    validation_X = validation_data.iloc[:, 1:]
+
+    predictions = clf.predict(validation_X)
+    error = mean_squared_error(validation_y, predictions)
+    print(f"Validation MSE: {error}")
     # Save the trained model
     joblib.dump(clf, os.path.join(args.model_dir, "model.joblib"))
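A note on the metric printed above: scikit-learn's mean_squared_error returns a plain MSE, not a root mean squared error, so if an RMSE is wanted the square root must be taken explicitly. A minimal sketch (the arrays below are hypothetical stand-ins for validation_y and clf.predict(validation_X)):

    import numpy as np
    from sklearn.metrics import mean_squared_error

    # Hypothetical stand-ins for validation_y and clf.predict(validation_X)
    y_true = np.array([2.5, 0.5, 2.0, 1.2])
    y_pred = np.array([3.0, 0.4, 1.8, 1.1])

    mse = mean_squared_error(y_true, y_pred)  # mean of squared residuals
    rmse = np.sqrt(mse)                       # root mean squared error
    print(f"MSE: {mse:.4f}  RMSE: {rmse:.4f}")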

scikit_learn_script_mode_local_training_and_serving/data/iris.csv (-150)

This file was deleted.
scikit_learn_script_mode_local_training_and_serving/requirements.txt (+1)

@@ -1,4 +1,5 @@
 numpy
 pandas
+scikit-learn
 sagemaker>=2.0.0,<3.0.0
 sagemaker[local]
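A small naming wrinkle in the dependency list above: the library is published on PyPI as scikit-learn, while the import name in the scripts is sklearn (the bare `sklearn` package on PyPI is a deprecated alias). A quick check that the pinned install resolves:

    # Installed as scikit-learn, imported as sklearn
    import sklearn
    print(sklearn.__version__)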

scikit_learn_script_mode_local_training_and_serving/scikit_learn_script_mode_local_training_and_serving.py (+57 -29)

@@ -1,4 +1,4 @@
-# This is a sample Python program that trains a simple scikit-learn model on the Iris dataset.
+# This is a sample Python program that trains a simple scikit-learn model on the California Housing dataset.
 # This implementation will work on your *local computer* or in the *AWS Cloud*.
 #
 # Prerequisites:
@@ -16,40 +16,54 @@
 import os

 from sagemaker.sklearn import SKLearn
+import sagemaker
+import boto3
 from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_squared_error

-DUMMY_IAM_ROLE = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'
+local_mode = True
+
+if local_mode:
+    instance_type = "local"
+    IAM_ROLE = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'
+else:
+    instance_type = "ml.m5.xlarge"
+    IAM_ROLE = 'arn:aws:iam::<ACCOUNT>:role/service-role/AmazonSageMaker-ExecutionRole-XXX'
+
+sess = sagemaker.Session()
+bucket = sess.default_bucket()  # Set a default S3 bucket
+prefix = 'DEMO-local-and-managed-infrastructure'

 def download_training_and_eval_data():
-    if os.path.isfile('./data/iris.csv'):
-        print('Training and dataset exist. Skipping Download')
-    else:
-        print('Downloading training dataset')
+    print('Downloading training dataset')

-        # Load Iris dataset, then join labels and features
-        iris = datasets.load_iris()
-        joined_iris = np.insert(iris.data, 0, iris.target, axis=1)
+    # Load California Housing dataset, then join labels and features
+    california = datasets.fetch_california_housing()
+    dataset = np.insert(california.data, 0, california.target, axis=1)
+    # Create directory and write csv
+    os.makedirs("./data/train", exist_ok=True)
+    os.makedirs("./data/validation", exist_ok=True)
+    os.makedirs("./data/test", exist_ok=True)

-        # Create directory and write csv
-        os.makedirs("./data", exist_ok=True)
-        np.savetxt("./data/iris.csv", joined_iris, delimiter=",", fmt="%1.1f, %1.3f, %1.3f, %1.3f, %1.3f")
+    train, other = train_test_split(dataset, test_size=0.3)
+    validation, test = train_test_split(other, test_size=0.5)

-        print('Downloading completed')
+    np.savetxt("./data/train/california_train.csv", train, delimiter=",")
+    np.savetxt("./data/validation/california_validation.csv", validation, delimiter=",")
+    np.savetxt("./data/test/california_test.csv", test, delimiter=",")
+
+    print('Downloading completed')
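The two chained train_test_split calls above yield a 70/15/15 train/validation/test split: the first call holds out 30% of the rows, and the second halves that holdout. A quick sanity check of the proportions on a dummy array of the same shape (California Housing has 20640 rows, and the saved csv carries 1 label column plus 8 feature columns):

    import numpy as np
    from sklearn.model_selection import train_test_split

    dummy = np.zeros((20640, 9))  # same shape as the label+features array above
    train, other = train_test_split(dummy, test_size=0.3)
    validation, test = train_test_split(other, test_size=0.5)
    print(len(train), len(validation), len(test))  # 14448 3096 3096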

 def do_inference_on_local_endpoint(predictor):
     print(f'\nStarting Inference on endpoint (local).')
-    shape = pd.read_csv("data/iris.csv", header=None)
-
-    a = [50 * i for i in range(3)]
-    b = [40 + i for i in range(10)]
-    indices = [i + j for i, j in itertools.product(a, b)]
-
-    test_data = shape.iloc[indices[:-1]]
+    test_data = pd.read_csv("data/test/california_test.csv", header=None)
     test_X = test_data.iloc[:, 1:]
     test_y = test_data.iloc[:, 0]
-    print("Predictions: {}".format(predictor.predict(test_X.values)))
+    predictions = predictor.predict(test_X.values)
+    print("Predictions: {}".format(predictions))
     print("Actual: {}".format(test_y.values))
-
+    print(f"MSE: {mean_squared_error(test_y.values, predictions)}")

 def main():
     download_training_and_eval_data()
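One piece this diff does not show: for the deploy step in the next hunk to serve the model, the SageMaker scikit-learn container expects the entry-point script to define a model_fn that loads the saved artifact (the hunks above only cover changed regions, so the existing function is simply out of view). A sketch of what it typically looks like, assuming the model.joblib name used by the training script:

    import os
    import joblib

    # Called by the scikit-learn serving container to load the model at startup.
    def model_fn(model_dir):
        return joblib.load(os.path.join(model_dir, "model.joblib"))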
@@ -58,21 +72,35 @@ def main():
     print('Note: if launching for the first time in local mode, container image download might take a few minutes to complete.')

     sklearn = SKLearn(
-        entry_point="scikit_learn_iris.py",
+        entry_point="scikit_learn_california.py",
         source_dir='code',
         framework_version="1.0-1",
-        instance_type="local",
-        role=DUMMY_IAM_ROLE,
+        instance_type=instance_type,
+        role=IAM_ROLE,
         hyperparameters={"max_leaf_nodes": 30},
     )

-    train_input = "file://./data/iris.csv"
+    if local_mode:
+        train_input = "file://./data/train/california_train.csv"
+        validation_input = "file://./data/validation/california_validation.csv"
+    else:
+        # upload data to S3
+        boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'data/train/california_train.csv')).upload_file('data/train/california_train.csv')
+        boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'data/validation/california_validation.csv')).upload_file('data/validation/california_validation.csv')
+        boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'data/test/california_test.csv')).upload_file('data/test/california_test.csv')

-    sklearn.fit({"train": train_input})
+        train_input = f"s3://{bucket}/{prefix}/data/train/california_train.csv"
+        validation_input = f"s3://{bucket}/{prefix}/data/validation/california_validation.csv"
+        test_input = f"s3://{bucket}/{prefix}/data/test/california_test.csv"
+
+    sklearn.fit({"train": train_input, "validation": validation_input})
     print('Completed model training')

-    print('Deploying endpoint in local mode')
-    predictor = sklearn.deploy(initial_instance_count=1, instance_type='local')
+    if local_mode:
+        print('Deploying endpoint in local mode')
+    else:
+        print(f"Deploying on the SageMaker managed infrastructure using a {instance_type} instance type")
+    predictor = sklearn.deploy(initial_instance_count=1, instance_type=instance_type)

     do_inference_on_local_endpoint(predictor)
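The sample ends right after the inference call, so the endpoint is left running: in local mode that is a lingering serving container, and on the managed infrastructure it keeps accruing cost until deleted. A cleanup sketch, assuming predictor is the object returned by sklearn.deploy above:

    # Delete the endpoint (and its config), then the model resource.
    # In local mode this stops the local serving container.
    predictor.delete_endpoint()
    predictor.delete_model()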
