Commit e9528a3b authored by dasharatha.vamshi's avatar dasharatha.vamshi

changes

parent 7273d48b
import json
import traceback
from scripts.common.config_parser import config
from scripts.common.constants import ModelObjectConstants, ComponentExceptions
......@@ -28,6 +29,8 @@ class RandomForest:
"bootstrap": self.bootstrap,
"criterion": self.criterion
}
self.meta_data_path = os.path.join(self.component_out_dir, "metadata.json")
self.meta_data_content = self.query['meta_data']
def random_forest(self):
try:
......@@ -35,10 +38,14 @@ class RandomForest:
logger.info("These are the parameters used " + str(self.parameters))
rf = RandomForestRegressor(n_estimators=self.n_estimators, min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf, max_features=self.max_features,
max_depth=self.max_depth, bootstrap=self.bootstrap,criterion=self.criterion)
max_depth=self.max_depth, bootstrap=self.bootstrap, criterion=self.criterion)
filename = 'random_forest.pkl'
logger.info("Pickling the model.....")
pickle.dump(rf, open(os.path.join(self.component_out_dir, filename), 'wb'))
logger.info("Creating metadata json file.........")
json_object = json.dumps(self.meta_data_content, indent=4)
with open(self.meta_data_path, "w") as outfile:
outfile.write(json_object)
return True
except Exception as e:
logger.info(traceback.format_exc())
......
......@@ -27,7 +27,6 @@ else:
# uncomment for Testing
# os.environ['pipeline_id'] = "pipe1"
# os.environ['model_name'] = 'RANDOMFOREST'
# os.environ['bootstrap'] = "True"
# os.environ['max_features'] = 'sqrt'
# os.environ['n_estimators'] = "800"
......@@ -102,7 +101,21 @@ config = {
'min_samples_split': min_samples_split,
'n_estimators': n_estimators,
'criterion': criterion,
'max_depth': max_depth
'max_depth': max_depth,
'meta_data': {
"model_name": model_name,
"model_params": {
'bootstrap': bootstrap,
'max_features': max_features,
'min_samples_leaf': min_samples_leaf,
'min_samples_split': min_samples_split,
'n_estimators': n_estimators,
'criterion': criterion,
'max_depth': max_depth,
},
"serializedObjectType": "pkl",
"framework": "sklearn"
}
}
if not os.path.exists(config['shared_mount_base_ai_job']):
sys.stderr.write("Shared path does not exist!" + "\n")
......
......@@ -2,7 +2,7 @@
class ModelObjectConstants:
NEXT_COMPONENT = "ModelFitting"
NEXT_COMPONENT = "TrainModel"
COMPONENT_NAME = "RandomForest"
HTTP = "http://"
RANDOMFOREST = "randomforest"
......
{
"model_name": "randomforest",
"model_params": {
"bootstrap": true,
"max_features": "sqrt",
"min_samples_leaf": 4,
"min_samples_split": 10,
"n_estimators": 900,
"criterion": "mse",
"max_depth": 100
},
"serializedObjectType": "pkl",
"framework": "sklearn"
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment