Confusion detection with EEG signals¶
Credit: AITS Cainvas Community
Photo by George Vald on Dribbble
Detecting whether or not a person is confused, based on their EEG recordings.
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models, optimizers, losses, callbacks
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import random
The dataset¶
Wang, H., Li, Y., Hu, X., Yang, Y., Meng, Z., & Chang, K. M. (2013, June). Using EEG to Improve Massive Open Online Courses Feedback Interaction. In AIED Workshops. PDF
EEG signal data was collected from 10 college students while they watched MOOC video clips on subjects ranging from simple ones, like basic algebra and geometry, to topics such as stem cell research and quantum mechanics that can be confusing to viewers unfamiliar with them. There were 20 videos, 10 simple and 10 complex, each two minutes long. The clips were cropped from the middle of a topic to make them more confusing.
The students wore a single-channel wireless MindSet that measured activity over the frontal lobe. The MindSet measures the voltage between an electrode resting on the forehead and two electrodes (one ground and one reference) each in contact with an ear.
There are two label columns - user-defined label (self-labelled by the students based on their experience) and predefined label (whether the students were expected to be confused by the video).
df = pd.read_csv('https://cainvas-static.s3.amazonaws.com/media/user_data/cainvas-admin/EEG_data.csv')
df
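Before building the time windows, a quick sanity check of the structure described above; the column names printed here are the ones the later cells rely on:
# Sanity check: column names, and counts of subjects and video IDs
print(df.columns.tolist())
print(df['SubjectID'].nunique(), 'subjects,', df['VideoID'].nunique(), 'distinct video IDs')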
# Defining the time window, that is, how many timesteps to include
time_window = 5
# Rows grouped by subject
df_subject_grouped = df.groupby('SubjectID')
# Column values affected by time
time_affected_columns = list(df.columns)
time_affected_columns.remove('SubjectID')
time_affected_columns.remove('VideoID')
time_affected_columns.remove('predefinedlabel')
time_affected_columns.remove('user-definedlabeln')
# Final dataframe
df_final = pd.DataFrame()
# For each subject
for subject in df_subject_grouped:
    # For each video
    for video in subject[1].groupby('VideoID'):
        # Keep only videos with at least time_window timesteps; discard the rest
        if time_window <= len(video[1]):
            # Slide the window over the video, starting at the time_window-th row
            for row_num in range(time_window, len(video[1]) + 1):
                # Pick the current (time_window-th) row, with all its columns
                df_temp = video[1].iloc[row_num - 1, :]
                # Append values from the time_window-1 rows before it
                for i in range(1, time_window):
                    df_temp_i = video[1].iloc[row_num - 1 - i][time_affected_columns]   # pick only the time-affected columns
                    df_temp = pd.concat([df_temp, df_temp_i], axis = 0)   # append values
                df_temp = df_temp.to_frame().transpose()   # Series to DataFrame
                df_final = pd.concat([df_final, df_temp])   # add as a row to the final dataframe
# Reset index
df_final = df_final.reset_index(drop = True)
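The nested loops above are easy to read but slow on larger frames. For reference, here is a vectorized sketch of the same windowing using pandas shift (an alternative, not what this notebook runs; df_windowed is a hypothetical name, and the lagged columns get _lagN suffixes instead of duplicated names):
# Vectorized alternative: per (subject, video) group, concatenate lagged
# copies of the time-affected columns, then drop rows without full history
frames = []
for _, g in df.groupby(['SubjectID', 'VideoID']):
    if len(g) < time_window:
        continue   # discard videos shorter than the window
    parts = [g]
    for lag in range(1, time_window):
        parts.append(g[time_affected_columns].shift(lag).add_suffix('_lag%d' % lag))
    frames.append(pd.concat(parts, axis = 1).iloc[time_window - 1:])
df_windowed = pd.concat(frames).reset_index(drop = True)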
Dropping unwanted columns¶
SubjectID and VideoID should not influence the final results and hence are removed. The user-defined labels are a more reliable measure of confusion than the predefined labels, so the latter are dropped as well.
df = df_final.drop(columns = ['SubjectID', 'VideoID', 'predefinedlabel'])
df['user-definedlabeln'].value_counts()
This is an almost balanced dataset.
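The exact proportions can be checked with normalized counts:
# Class proportions of the user-defined label
df['user-definedlabeln'].value_counts(normalize = True)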
Train-val-test split¶
# Splitting into train, val and test set -- 80-10-10 split
# First, an 80-20 split
train_df, val_test_df = train_test_split(df, test_size = 0.2, random_state = 113)
# Then split the 20% into half
val_df, test_df = train_test_split(val_test_df, test_size = 0.5, random_state = 113)
len(train_df), len(val_df), len(test_df)
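Since the split is random, the class ratio can drift slightly between subsets. If that mattered, the same split could be stratified on the label (a variant, not what this notebook uses):
# Stratified variant of the same 80-10-10 split
train_df, val_test_df = train_test_split(df, test_size = 0.2, random_state = 113,
                                         stratify = df['user-definedlabeln'])
val_df, test_df = train_test_split(val_test_df, test_size = 0.5, random_state = 113,
                                   stratify = val_test_df['user-definedlabeln'])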
# Input (ic) and output (oc) column lists
ic = df.columns.tolist()
ic.remove('user-definedlabeln')
oc = ['user-definedlabeln']
ytrain = train_df[oc]
Xtrain = train_df.drop(columns = oc)
yval = val_df[oc]
Xval = val_df.drop(columns = oc)
ytest = test_df[oc]
Xtest = test_df.drop(columns = oc)
Standardization¶
Scaling the values to have mean = 0 and standard deviation = 1.
ss = StandardScaler()
Xtrain = ss.fit_transform(Xtrain)
Xval = ss.transform(Xval)
Xtest = ss.transform(Xtest)
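A quick check that the training features now have mean close to 0 and standard deviation close to 1 (the validation and test sets will be close but not exact, since they reuse the training statistics):
# Verify the standardization on the training set
print(Xtrain.mean(axis = 0).round(2))
print(Xtrain.std(axis = 0).round(2))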
The model¶
model = models.Sequential([
layers.Dense(32, activation = 'relu', input_shape = Xtrain[0].shape),
layers.Dropout(0.2),
layers.Dense(16, activation = 'relu'),
layers.Dense(1, activation = 'sigmoid')
])
cb = callbacks.EarlyStopping(patience = 5, restore_best_weights = True)
model.summary()
model.compile(optimizer = optimizers.Adam(0.01), loss = losses.BinaryCrossentropy(), metrics = ['accuracy'])
history = model.fit(Xtrain, ytrain, validation_data = (Xval, yval), epochs = 256, callbacks = [cb])
model.evaluate(Xtest, ytest)
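accuracy_score and f1_score were imported above but never used; they can score the thresholded test predictions directly:
# Threshold the sigmoid outputs at 0.5 and score with sklearn
ypred = (model.predict(Xtest) > 0.5).astype('int').ravel()
ytrue = np.array(ytest).ravel()
print('Accuracy:', accuracy_score(ytrue, ypred))
print('F1 score:', f1_score(ytrue, ypred))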
cm = confusion_matrix(ytest, (model.predict(Xtest)>0.5).astype('int'))
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]   # normalize each row to proportions
fig = plt.figure(figsize = (10, 10))
ax = fig.add_subplot(111)
# Annotate each cell, switching the text colour on dark cells
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        clr = "white" if cm[i, j] > 0.8 else "black"
        ax.text(j, i, format(cm[i, j], '.2f'), horizontalalignment="center", color=clr)
_ = ax.imshow(cm, cmap=plt.cm.Blues)
ax.set_xticks(range(2))
ax.set_yticks(range(2))
ax.set_xticklabels(range(2), rotation = 90)
ax.set_yticklabels(range(2))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
The low accuracy may improve with better-labelled data. Self-reported labels of a mental state are easy to get wrong, so some degree of mislabelling is likely.
Plotting the metrics¶
def plot(history, variable, variable2):
    # Plot a metric and its validation counterpart against the epoch number
    plt.plot(range(len(history[variable])), history[variable])
    plt.plot(range(len(history[variable2])), history[variable2])
    plt.legend([variable, variable2])
    plt.title(variable)
    plt.show()
plot(history.history, "accuracy", 'val_accuracy')
plot(history.history, "loss", "val_loss")
Prediction¶
# pick a random sample from the test set
x = random.randint(0, len(Xtest) - 1)
output = model.predict(Xtest[x].reshape(1, -1))[0][0]
pred = (output>0.5).astype('int')
print("Predicted: ", pred, "(", output, "-->", pred, ")")
print("True: ", np.array(ytest)[x][0])
deepC¶
model.save('confused.h5')
!deepCC confused.h5
# pick a random sample from the test set
x = random.randint(0, len(Xtest) - 1)
np.savetxt('sample.data', Xtest[x]) # xth sample into text file
# run exe with input
!confused_deepC/confused.exe sample.data
# show predicted output
nn_out = np.loadtxt('deepSea_result_1.out')
pred = (nn_out>0.5).astype('int')
print("Predicted: ", pred, "(", nn_out, "-->", pred, ")")
print("True: ", np.array(ytest)[x][0])