Skip to content
This repository has been archived by the owner on Nov 8, 2018. It is now read-only.

Model predicts same value for all the features. #59

Open
ankushreddy opened this issue Mar 8, 2018 · 1 comment
Open

Model predicts same value for all the features. #59

ankushreddy opened this issue Mar 8, 2018 · 1 comment

Comments

@ankushreddy
Copy link

My model is predicting the same value for all inputs.

I am trying to predict the color of images.
I created a dataframe with 12,000 images. The schema of the dataframe is:

  df.printSchema()
   root
       |-- Image_Id_new: string (nullable = true)
       |-- Color: string (nullable = true)
       |-- rawfeatures: vector (nullable = true)

  # Number of color classes to one-hot encode.
  nb_classes = 11
  # One-hot encode the "Color" column into a "label_encoded" vector column.
  # NOTE(review): "Color" is a string column per the schema above — confirm
  # OneHotTransformer accepts string labels rather than numeric indices.
  encoder = OneHotTransformer(nb_classes, input_col="Color", output_col="label_encoded")
  dataset_train = encoder.transform(final_train)
  dataset_test = encoder.transform(final_test)

  # Rename columns: rawfeatures -> features, Color -> label.
  dataset_train = dataset_train.selectExpr("rawfeatures as features", "Color as label", "label_encoded")
 dataset_test = dataset_test.selectExpr("rawfeatures as features", "Color as label", "label_encoded")

Clear the dataset in case you ran this cell before.

 # Keep only the columns needed downstream.
 dataset_train = dataset_train.select("features", "label", "label_encoded")
 dataset_test = dataset_test.select("features", "label", "label_encoded")

Allocate a MinMaxTransformer using Distributed Keras.

o_min -> original_minimum

n_min -> new_minimum

 # Rescale feature values from the original range [o_min, o_max] to [n_min, n_max].
 # NOTE(review): o_max is 250 here, but 8-bit pixel intensities range up to 255 —
 # confirm 250 is intentional.
 transformer = MinMaxTransformer(n_min=0.0, n_max=1.0, \
                            o_min=0.0, o_max=250.0, \
                            input_col="features", \
                            output_col="features_normalized")

Transform the dataset.

 # Apply the min-max normalization to both splits.
 dataset_train = transformer.transform(dataset_train)
 dataset_test = transformer.transform(dataset_test)

  # Reshape the flat normalized vector into a (100, 100, 3) tensor column "matrix".
  reshape_transformer = ReshapeTransformer("features_normalized", "matrix", (100, 100, 3))
 dataset_train = reshape_transformer.transform(dataset_train)
 dataset_test = reshape_transformer.transform(dataset_test)

 # Multilayer perceptron: 30000 inputs (flattened 100x100x3 image) -> 11-way softmax.
 mlp = Sequential()
 # NOTE(review): the first Dense layer has only 11 units, a severe bottleneck
 # before the following 128-unit layer — possibly a typo; confirm the intended
 # width. Such a bottleneck could contribute to near-constant predictions.
 mlp.add(Dense(11, input_shape=(30000,)))
  mlp.add(Activation('relu'))
  mlp.add(Dropout(0.2))
  mlp.add(Dense(128))
  mlp.add(Activation('relu'))
  mlp.add(Dropout(0.5))
  # Output layer: one unit per color class, softmax for class probabilities.
  mlp.add(Dense(11))
  mlp.add(Activation('softmax'))
  mlp.summary()

 # Worker-side optimizer and loss for distributed training.
 optimizer_mlp = 'adam'
 loss_mlp = 'categorical_crossentropy'

 def evaluate_accuracy(model, test_set, features="features_normalized_dense"):
     """Return the classification accuracy of `model` on `test_set`.

     Runs the trained Keras model over the Spark DataFrame, maps the raw
     prediction vector to a class index, and scores it against "label".

     Args:
         model: trained Keras model to evaluate.
         test_set: Spark DataFrame containing `features` and "label" columns.
         features: name of the dense feature column fed to the model.

     Returns:
         Accuracy score as computed by dist-keras' AccuracyEvaluator.
     """
     # Fix: the original paste had inconsistent indentation (the `return` was
     # indented deeper than the body), which is an IndentationError in Python.
     evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
     predictor = ModelPredictor(keras_model=model, features_col=features)
     # NOTE: relies on `nb_classes` from the enclosing scope.
     transformer = LabelIndexTransformer(output_dim=nb_classes)
     test_set = test_set.select(features, "label")
     test_set = predictor.predict(test_set)
     test_set = transformer.transform(test_set)
     score = evaluator.evaluate(test_set)
     return score

  # Keep only the columns the trainer and evaluator need.
  dataset_train = dataset_train.select("features_normalized", "matrix","label", "label_encoded")
  dataset_test = dataset_test.select("features_normalized", "matrix","label", "label_encoded")

  # Convert the normalized vector column into a dense vector column.
  dense_transformer = DenseTransformer(input_col="features_normalized", 
               output_col="features_normalized_dense")
  dataset_train = dense_transformer.transform(dataset_train)
  dataset_test = dense_transformer.transform(dataset_test)
  # NOTE(review): DataFrame.repartition returns a new DataFrame; these two
  # calls discard their result and have no effect here (the repartitioned
  # DataFrames are assigned later to training_set/test_set).
  dataset_train.repartition(num_workers)
  dataset_test.repartition(num_workers)

Assign the training and test set.

 # Repartition so each of the `num_workers` workers receives a partition.
 training_set = dataset_train.repartition(num_workers)
  test_set = dataset_test.repartition(num_workers)

Cache them.

 # Cache both splits: memory first, spill to disk, replicated on 2 nodes.
 training_set.persist(StorageLevel.MEMORY_AND_DISK_2)
 test_set.persist(StorageLevel.MEMORY_AND_DISK_2)

  # Materialize the cache and sanity-check the row count.
  print(training_set.count())

# Train with DOWNPOUR (asynchronous data-parallel SGD from dist-keras).
# NOTE(review): num_workers is hard-coded to 1 here, while the repartitioning
# above used a `num_workers` variable — confirm these are meant to agree.
trainer = DOWNPOUR(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp, num_workers=1,
batch_size=32, communication_window=32, num_epoch=5,
features_col="features_normalized_dense", label_col="label_encoded")
trained_model = trainer.train(training_set)

  print("Training time: " + str(trainer.get_training_time()))

Training time: 235.8617208
print("Accuracy: " + str(evaluate_accuracy(trained_model, test_set)))
Accuracy: 0.248927038627
# Manually run prediction on the test set to inspect the raw model outputs
# (the "prediction" column shown in the table below).
evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
predictor = ModelPredictor(keras_model=trained_model, features_col="features_normalized_dense")
transformer = LabelIndexTransformer(output_dim=nb_classes)
test_set = test_set.select("features_normalized_dense", "label")
test_set = predictor.predict(test_set)

  test_set.select("label","prediction").show(truncate=False)

+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|prediction |
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|8 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|7 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|4 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|6 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|10 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|1 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|3 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|2 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|10 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|7 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|3 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|2 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|1 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

@ankushreddy
Copy link
Author

Hi @JoeriHermans, can you please look at this and let me know if I am missing anything?

I have been trying to get this working for the past two weeks.
Thank You for your help.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant