Cainvas

Drug classification

Credit: AITS Cainvas Community

Photo by Vadim Gromov on Dribbble

Training a deep learning model to prescribe a drug based on the patient's data.

In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping
import random
import matplotlib.pyplot as plt

Dataset

On Kaggle by Pratham Tripathi.

The datatset is a CSV file with the features regarding a patient that affects drug prescriptions like age, sex, BP level, cholestrol, and sodium-potassium ratio and the corresponding drug prescribes in each case.

In [2]:
df = pd.read_csv('https://cainvas-static.s3.amazonaws.com/media/user_data/cainvas-admin/drug200.csv')
df
Out[2]:
Age Sex BP Cholesterol Na_to_K Drug
0 23 F HIGH HIGH 25.355 DrugY
1 47 M LOW HIGH 13.093 drugC
2 47 M LOW HIGH 10.114 drugC
3 28 F NORMAL HIGH 7.798 drugX
4 61 F LOW HIGH 18.043 DrugY
... ... ... ... ... ... ...
195 56 F LOW HIGH 11.567 drugC
196 16 M LOW HIGH 12.006 drugC
197 52 M NORMAL HIGH 9.894 drugX
198 23 M NORMAL NORMAL 14.020 drugX
199 40 F LOW NORMAL 11.349 drugX

200 rows × 6 columns

In [3]:
df['Drug'].value_counts()
Out[3]:
DrugY    91
drugX    54
drugA    23
drugC    16
drugB    16
Name: Drug, dtype: int64

This is an unbalanced dataset.

Preprocessing

Balancing the dataset

In order to balance the dataset, there are two options,

  • upsampling - resample the values to make their count equal to the class label with the higher count (here, 1655).
  • downsampling - pick n samples from each class label where n = number of samples in class with least count (here, 176)

Here, we will be upsampling.

In [4]:
categories = np.unique(df.Drug.to_list())
df_balanced = pd.DataFrame()

for i in range(len(categories)):
    # separating into individual dataframes, one for each class 
    dfi = df[df['Drug'] == categories[i]]
    # resampling
    dfi = dfi.sample(91, replace = True)
    # appending all to one to form a final balanced dataframe
    df_balanced = df_balanced.append(dfi)

df_balanced['Drug'].value_counts()
Out[4]:
drugX    91
drugB    91
drugC    91
drugA    91
DrugY    91
Name: Drug, dtype: int64

Categorical variables

The 'sex' column does not define a range and thus is one-hot encoded while changing from a categorical variable to a numerical attribute

In [5]:
dfx = pd.get_dummies(df_balanced[df_balanced.columns[:-1]], drop_first = True, columns = ['Sex'])
dfx
Out[5]:
Age BP Cholesterol Na_to_K Sex_M
49 28 LOW HIGH 19.796 0
172 39 NORMAL NORMAL 17.225 0
178 39 NORMAL HIGH 15.969 1
62 67 LOW NORMAL 20.693 1
71 28 NORMAL HIGH 19.675 0
... ... ... ... ... ...
16 69 LOW NORMAL 11.455 1
95 36 LOW NORMAL 11.424 1
58 60 NORMAL NORMAL 10.091 1
35 46 NORMAL NORMAL 7.285 1
181 59 NORMAL HIGH 13.884 0

455 rows × 5 columns

In [6]:
print("Values in BP column:", np.unique(dfx['BP']))
print("Values in Cholesterol column:", np.unique(dfx['Cholesterol']))
Values in BP column: ['HIGH' 'LOW' 'NORMAL']
Values in Cholesterol column: ['HIGH' 'NORMAL']

The values in the columns Cholesterol and BP represent range as seen by the values above.

In [7]:
le_bp = LabelEncoder()
le_bp.fit(['LOW', 'NORMAL', 'HIGH'])
dfx['BP'] = le_bp.transform(dfx['BP'], )
print("BP classes:", le_bp.classes_)

le_ch = LabelEncoder()
le_ch.fit(['NORMAL', 'HIGH'])
dfx['Cholesterol'] = le_ch.transform(dfx['Cholesterol'])
print("Cholesterol classes:", le_ch.classes_)

print(dfx)
BP classes: ['HIGH' 'LOW' 'NORMAL']
Cholesterol classes: ['HIGH' 'NORMAL']
     Age  BP  Cholesterol  Na_to_K  Sex_M
49    28   1            0   19.796      0
172   39   2            1   17.225      0
178   39   2            0   15.969      1
62    67   1            1   20.693      1
71    28   2            0   19.675      0
..   ...  ..          ...      ...    ...
16    69   1            1   11.455      1
95    36   1            1   11.424      1
58    60   2            1   10.091      1
35    46   2            1    7.285      1
181   59   2            0   13.884      0

[455 rows x 5 columns]

Since this is a classification problem, the output of the model which is now as an integer should be one-hot encoded.

In [8]:
df_cat = pd.get_dummies(df_balanced['Drug'])
df_cat
Out[8]:
DrugY drugA drugB drugC drugX
49 1 0 0 0 0
172 1 0 0 0 0
178 1 0 0 0 0
62 1 0 0 0 0
71 1 0 0 0 0
... ... ... ... ... ...
16 0 0 0 0 1
95 0 0 0 0 1
58 0 0 0 0 1
35 0 0 0 0 1
181 0 0 0 0 1

455 rows × 5 columns

In [9]:
# defining the input and output columns to separate the dataset in the later cells.

input_columns = dfx.columns.to_list()
output_columns = df_cat.columns.to_list()

print("Number of input columns: ", len(input_columns))
#print("Input columns: ", ', '.join(input_columns))

print("Number of output columns: ", len(output_columns))
#print("Output columns: ", ', '.join(output_columns))
Number of input columns:  5
Number of output columns:  5
In [10]:
for i in output_columns:
    dfx[i] = df_cat[i]

del df_cat

dfx
Out[10]:
Age BP Cholesterol Na_to_K Sex_M DrugY drugA drugB drugC drugX
49 28 1 0 19.796 0 1 0 0 0 0
172 39 2 1 17.225 0 1 0 0 0 0
178 39 2 0 15.969 1 1 0 0 0 0
62 67 1 1 20.693 1 1 0 0 0 0
71 28 2 0 19.675 0 1 0 0 0 0
... ... ... ... ... ... ... ... ... ... ...
16 69 1 1 11.455 1 0 0 0 0 1
95 36 1 1 11.424 1 0 0 0 0 1
58 60 2 1 10.091 1 0 0 0 0 1
35 46 2 1 7.285 1 0 0 0 0 1
181 59 2 0 13.884 0 0 0 0 0 1

455 rows × 10 columns

Train test split

In [11]:
# Splitting into train, val and test set -- 80-10-10 split

# First, an 80-20 split
train_df, val_test_df = train_test_split(dfx, test_size = 0.2, random_state = 13)

# Then split the 20% into half
val_df, test_df = train_test_split(val_test_df, test_size = 0.5, random_state = 13)

print("Number of samples in...")
print("Training set: ", len(train_df))
print("Validation set: ", len(val_df))
print("Testing set: ", len(test_df))
Number of samples in...
Training set:  364
Validation set:  45
Testing set:  46
In [12]:
# Splitting into X (input) and y (output)

Xtrain, ytrain = np.array(train_df[input_columns]), np.array(train_df[output_columns])

Xval, yval = np.array(val_df[input_columns]), np.array(val_df[output_columns])

Xtest, ytest = np.array(test_df[input_columns]), np.array(test_df[output_columns])

Scaling the values

In [13]:
# Each feature has a different range. 
# Using min_max_scaler to scale them to values in the range [0,1].

min_max_scaler = MinMaxScaler()

# Fit on training set alone
Xtrain = min_max_scaler.fit_transform(Xtrain)

# Use it to transform val and test input
Xval = min_max_scaler.transform(Xval)
Xtest = min_max_scaler.transform(Xtest)

Model

In [14]:
model = Sequential([
    Dense(1024, activation = 'relu', input_shape = Xtrain[0].shape),
    Dense(512, activation = 'relu'),
    Dense(256, activation = 'relu'),
    Dense(64, activation = 'relu'),
    Dense(len(output_columns), activation = 'softmax')
])

cb = [EarlyStopping(monitor = 'val_loss', patience=8, restore_best_weights=True)]
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1024)              6144      
_________________________________________________________________
dense_1 (Dense)              (None, 512)               524800    
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_3 (Dense)              (None, 64)                16448     
_________________________________________________________________
dense_4 (Dense)              (None, 5)                 325       
=================================================================
Total params: 679,045
Trainable params: 679,045
Non-trainable params: 0
_________________________________________________________________
In [15]:
model.compile(optimizer=Adam(0.01), loss=CategoricalCrossentropy(), metrics=['accuracy'])

history = model.fit(Xtrain, ytrain, validation_data = (Xval, yval), epochs=64, callbacks=cb)
Epoch 1/64
12/12 [==============================] - 0s 13ms/step - loss: 1.3218 - accuracy: 0.4451 - val_loss: 0.7544 - val_accuracy: 0.6889
Epoch 2/64
12/12 [==============================] - 0s 3ms/step - loss: 0.6025 - accuracy: 0.7637 - val_loss: 0.3085 - val_accuracy: 0.8667
Epoch 3/64
12/12 [==============================] - 0s 3ms/step - loss: 0.3834 - accuracy: 0.8681 - val_loss: 0.1606 - val_accuracy: 0.9778
Epoch 4/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9258 - val_loss: 1.4389 - val_accuracy: 0.7333
Epoch 5/64
12/12 [==============================] - 0s 3ms/step - loss: 0.7271 - accuracy: 0.7995 - val_loss: 0.2889 - val_accuracy: 0.9556
Epoch 6/64
12/12 [==============================] - 0s 3ms/step - loss: 0.3600 - accuracy: 0.8901 - val_loss: 0.3921 - val_accuracy: 0.8889
Epoch 7/64
12/12 [==============================] - 0s 3ms/step - loss: 0.3028 - accuracy: 0.8929 - val_loss: 0.0872 - val_accuracy: 0.9556
Epoch 8/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1238 - accuracy: 0.9588 - val_loss: 0.0280 - val_accuracy: 1.0000
Epoch 9/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0576 - accuracy: 0.9725 - val_loss: 0.0210 - val_accuracy: 1.0000
Epoch 10/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2064 - accuracy: 0.9615 - val_loss: 1.0279 - val_accuracy: 0.8444
Epoch 11/64
12/12 [==============================] - 0s 3ms/step - loss: 0.8020 - accuracy: 0.8654 - val_loss: 0.3219 - val_accuracy: 0.8889
Epoch 12/64
12/12 [==============================] - 0s 3ms/step - loss: 0.3034 - accuracy: 0.8984 - val_loss: 0.3241 - val_accuracy: 0.8667
Epoch 13/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2108 - accuracy: 0.9148 - val_loss: 0.2620 - val_accuracy: 0.8444
Epoch 14/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1753 - accuracy: 0.9478 - val_loss: 0.1746 - val_accuracy: 0.8889
Epoch 15/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1110 - accuracy: 0.9670 - val_loss: 0.2109 - val_accuracy: 0.9556
Epoch 16/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0621 - accuracy: 0.9808 - val_loss: 0.0125 - val_accuracy: 1.0000
Epoch 17/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1596 - accuracy: 0.9533 - val_loss: 0.1547 - val_accuracy: 0.9778
Epoch 18/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2892 - accuracy: 0.9203 - val_loss: 0.3302 - val_accuracy: 0.9111
Epoch 19/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1974 - accuracy: 0.9505 - val_loss: 0.0508 - val_accuracy: 1.0000
Epoch 20/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0589 - accuracy: 0.9890 - val_loss: 0.0054 - val_accuracy: 1.0000
Epoch 21/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0785 - accuracy: 0.9753 - val_loss: 0.0038 - val_accuracy: 1.0000
Epoch 22/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2819 - accuracy: 0.9423 - val_loss: 0.2005 - val_accuracy: 0.9333
Epoch 23/64
12/12 [==============================] - 0s 3ms/step - loss: 0.2469 - accuracy: 0.9313 - val_loss: 0.0641 - val_accuracy: 0.9778
Epoch 24/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9725 - val_loss: 0.2020 - val_accuracy: 0.9556
Epoch 25/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1079 - accuracy: 0.9615 - val_loss: 0.1907 - val_accuracy: 0.9556
Epoch 26/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0567 - accuracy: 0.9863 - val_loss: 0.0198 - val_accuracy: 1.0000
Epoch 27/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0354 - accuracy: 0.9863 - val_loss: 0.0048 - val_accuracy: 1.0000
Epoch 28/64
12/12 [==============================] - 0s 3ms/step - loss: 0.1038 - accuracy: 0.9780 - val_loss: 0.0210 - val_accuracy: 1.0000
Epoch 29/64
12/12 [==============================] - 0s 3ms/step - loss: 0.0453 - accuracy: 0.9918 - val_loss: 0.0955 - val_accuracy: 0.9778
In [16]:
model.evaluate(Xtest, ytest)
2/2 [==============================] - 0s 1ms/step - loss: 0.0076 - accuracy: 1.0000
Out[16]:
[0.007588881067931652, 1.0]
In [17]:
cm = confusion_matrix(np.argmax(ytest, axis = 1), np.argmax(model.predict(Xtest), axis = 1))
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

for i in range(cm.shape[1]):
    for j in range(cm.shape[0]):
        plt.text(j, i, format(cm[i, j], '.2f'), horizontalalignment="center", color="black")


plt.imshow(cm, cmap=plt.cm.Blues)
Out[17]:
<matplotlib.image.AxesImage at 0x7fc0906ff128>

It is important to keep the accuracy extremely high (100%) as chances cannot be taken with a patient's medication.

Plotting the metrics

In [18]:
def plot(history, variable, variable1):
    plt.plot(range(len(history[variable])), history[variable])
    plt.plot(range(len(history[variable1])), history[variable1])
    plt.title(variable)
    plt.legend([variable, variable1])
In [19]:
plot(history.history, "accuracy", "val_accuracy")
In [20]:
plot(history.history, "loss", "val_loss")

Prediction

In [21]:
gender = ['M', 'F']

def print_sample(x):
    print("\nSample:")
    sample = np.array(test_df)[x]
    print("Age :", sample[0])
    print("Sex :", gender[int(sample[4])])
    print("Na to K ratio :", sample[3])
    print("BP :", le_bp.classes_[int(sample[1])])
    print("Cholesterol :", le_ch.classes_[int(sample[2])])
    print()
In [22]:
# pick random test data sample from one batch
x = random.randint(0, len(Xtest) - 1)

print_sample(x)

output = model.predict(Xtest[x].reshape(1, -1))    # getting output; input shape (256, 256, 3) --> (1, 256, 256, 3)
pred = np.argmax(output[0])    # finding max
print("Predicted: ", output_columns[pred])    # Picking the label from class_names base don the model output

output_true = np.array(ytest)[x]

print("True: ", output_columns[np.argmax(output_true)])
print("Probability: ", output[0][pred])
Sample:
Age : 59.0
Sex : F
Na to K ratio : 13.935
BP : HIGH
Cholesterol : HIGH

Predicted:  drugB
True:  drugB
Probability:  0.99355537

deepC

In [23]:
model.save('drug.h5')

!deepCC drug.h5
[INFO]
Reading [keras model] 'drug.h5'
[SUCCESS]
Saved 'drug_deepC/drug.onnx'
[INFO]
Reading [onnx model] 'drug_deepC/drug.onnx'
[INFO]
Model info:
  ir_vesion : 4
  doc       : 
[WARNING]
[ONNX]: terminal (input/output) dense_input's shape is less than 1. Changing it to 1.
[WARNING]
[ONNX]: terminal (input/output) dense_4's shape is less than 1. Changing it to 1.
WARN (GRAPH): found operator node with the same name (dense_4) as io node.
[INFO]
Running DNNC graph sanity check ...
[SUCCESS]
Passed sanity check.
[INFO]
Writing C++ file 'drug_deepC/drug.cpp'
[INFO]
deepSea model files are ready in 'drug_deepC/' 
[RUNNING COMMAND]
g++ -std=c++11 -O3 -fno-rtti -fno-exceptions -I. -I/opt/tljh/user/lib/python3.7/site-packages/deepC-0.13-py3.7-linux-x86_64.egg/deepC/include -isystem /opt/tljh/user/lib/python3.7/site-packages/deepC-0.13-py3.7-linux-x86_64.egg/deepC/packages/eigen-eigen-323c052e1731 "drug_deepC/drug.cpp" -D_AITS_MAIN -o "drug_deepC/drug.exe"
[RUNNING COMMAND]
size "drug_deepC/drug.exe"
   text	   data	    bss	    dec	    hex	filename
2842085	   2984	    760	2845829	 2b6c85	drug_deepC/drug.exe
[SUCCESS]
Saved model as executable "drug_deepC/drug.exe"
In [24]:
# pick random test data sample from one batch
x = random.randint(0, len(Xtest) - 1)

np.savetxt('sample.data', Xtest[x])    # xth sample into text file

# run exe with input
!drug_deepC/drug.exe sample.data

# show predicted output
nn_out = np.loadtxt('deepSea_result_1.out')

print_sample(x)

pred = np.argmax(nn_out)    # finding max
print("Predicted: ", output_columns[pred])    # Picking the label from class_names base don the model output

output_true = np.array(ytest)[x]

print("True: ", output_columns[np.argmax(output_true)])
print("Probability: ", nn_out[pred])
writing file deepSea_result_1.out.

Sample:
Age : 50.0
Sex : F
Na to K ratio : 7.49
BP : HIGH
Cholesterol : HIGH

Predicted:  drugA
True:  drugA
Probability:  0.999999