You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2134 lines
120 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"#IMPORTS\n",
"\n",
"import numpy as np\n",
"import random\n",
"import tensorflow as tf\n",
"import tensorflow.keras as kr\n",
"import tensorflow.keras.backend as K\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense\n",
"from tensorflow.keras.datasets import mnist\n",
"import os\n",
"import csv\n",
"\n",
"from scipy.spatial.distance import euclidean\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"from time import sleep\n",
"from tqdm import tqdm\n",
"\n",
"import copy\n",
"import numpy\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import seaborn as sns\n",
"from numpy.random import RandomState\n",
"import scipy as scp\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"from keras import optimizers\n",
"from keras.callbacks import EarlyStopping,ModelCheckpoint\n",
"from keras.utils import to_categorical\n",
"from keras import backend as K\n",
"from itertools import product\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import precision_score\n",
"from sklearn.metrics import recall_score\n",
"from sklearn.metrics import f1_score\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"from sklearn import mixture\n",
"\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Data set to explain: 'adult', 'activity' or 'synthatic' (sic - the synthetic\n",
"# data cell checks for exactly this spelling).\n",
"\n",
"data_set = 'adult'\n",
"\n",
"# Number of peers taking part in the experiments\n",
"\n",
"n_peers = 100\n",
"\n",
"# Fail fast on a typo: an unknown value would silently skip every data-loading\n",
"# cell and only surface later as a NameError (X_train, Features_number, ...).\n",
"assert data_set in ('adult', 'activity', 'synthatic'), 'unknown data_set: %r' % data_set"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Shared random state for the experiments (it is passed as random_state to the\n",
"# train/test splits below); change the seed to get a different randomisation.\n",
"\n",
"rs = RandomState(seed=92)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 45222 entries, 0 to 45221\n",
"Data columns (total 14 columns):\n",
"age 45222 non-null float64\n",
"workclass 45222 non-null float64\n",
"educational-num 45222 non-null float64\n",
"marital-status 45222 non-null float64\n",
"occupation 45222 non-null float64\n",
"relationship 45222 non-null float64\n",
"race 45222 non-null float64\n",
"gender 45222 non-null float64\n",
"capital-gain 45222 non-null float64\n",
"capital-loss 45222 non-null float64\n",
"hours-per-week 45222 non-null float64\n",
"native-country 45222 non-null float64\n",
"income_<=50K 45222 non-null uint8\n",
"income_>50K 45222 non-null uint8\n",
"dtypes: float64(12), uint8(2)\n",
"memory usage: 4.2 MB\n"
]
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>educational-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>income_&lt;=50K</th>\n",
" <th>income_&gt;50K</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.109589</td>\n",
" <td>0.333333</td>\n",
" <td>0.400000</td>\n",
" <td>1.0</td>\n",
" <td>0.461538</td>\n",
" <td>0.6</td>\n",
" <td>0.5</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.287671</td>\n",
" <td>0.333333</td>\n",
" <td>0.533333</td>\n",
" <td>0.0</td>\n",
" <td>0.307692</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.500000</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.150685</td>\n",
" <td>0.166667</td>\n",
" <td>0.733333</td>\n",
" <td>0.0</td>\n",
" <td>0.769231</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.369863</td>\n",
" <td>0.333333</td>\n",
" <td>0.600000</td>\n",
" <td>0.0</td>\n",
" <td>0.461538</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>1.0</td>\n",
" <td>0.076881</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.232877</td>\n",
" <td>0.333333</td>\n",
" <td>0.333333</td>\n",
" <td>1.0</td>\n",
" <td>0.538462</td>\n",
" <td>0.2</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.295918</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.630137</td>\n",
" <td>0.666667</td>\n",
" <td>0.933333</td>\n",
" <td>0.0</td>\n",
" <td>0.692308</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.031030</td>\n",
" <td>0.0</td>\n",
" <td>0.316327</td>\n",
" <td>0.95</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.095890</td>\n",
" <td>0.333333</td>\n",
" <td>0.600000</td>\n",
" <td>1.0</td>\n",
" <td>0.538462</td>\n",
" <td>0.8</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.520548</td>\n",
" <td>0.333333</td>\n",
" <td>0.200000</td>\n",
" <td>0.0</td>\n",
" <td>0.153846</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.091837</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.657534</td>\n",
" <td>0.333333</td>\n",
" <td>0.533333</td>\n",
" <td>0.0</td>\n",
" <td>0.461538</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.064181</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.260274</td>\n",
" <td>0.000000</td>\n",
" <td>0.800000</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.397959</td>\n",
" <td>0.95</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass educational-num marital-status occupation \\\n",
"0 0.109589 0.333333 0.400000 1.0 0.461538 \n",
"1 0.287671 0.333333 0.533333 0.0 0.307692 \n",
"2 0.150685 0.166667 0.733333 0.0 0.769231 \n",
"3 0.369863 0.333333 0.600000 0.0 0.461538 \n",
"4 0.232877 0.333333 0.333333 1.0 0.538462 \n",
"5 0.630137 0.666667 0.933333 0.0 0.692308 \n",
"6 0.095890 0.333333 0.600000 1.0 0.538462 \n",
"7 0.520548 0.333333 0.200000 0.0 0.153846 \n",
"8 0.657534 0.333333 0.533333 0.0 0.461538 \n",
"9 0.260274 0.000000 0.800000 0.0 0.000000 \n",
"\n",
" relationship race gender capital-gain capital-loss hours-per-week \\\n",
"0 0.6 0.5 1.0 0.000000 0.0 0.397959 \n",
"1 0.0 1.0 1.0 0.000000 0.0 0.500000 \n",
"2 0.0 1.0 1.0 0.000000 0.0 0.397959 \n",
"3 0.0 0.5 1.0 0.076881 0.0 0.397959 \n",
"4 0.2 1.0 1.0 0.000000 0.0 0.295918 \n",
"5 0.0 1.0 1.0 0.031030 0.0 0.316327 \n",
"6 0.8 1.0 0.0 0.000000 0.0 0.397959 \n",
"7 0.0 1.0 1.0 0.000000 0.0 0.091837 \n",
"8 0.0 1.0 1.0 0.064181 0.0 0.397959 \n",
"9 0.0 1.0 1.0 0.000000 0.0 0.397959 \n",
"\n",
" native-country income_<=50K income_>50K \n",
"0 0.95 1 0 \n",
"1 0.95 1 0 \n",
"2 0.95 0 1 \n",
"3 0.95 0 1 \n",
"4 0.95 1 0 \n",
"5 0.95 0 1 \n",
"6 0.95 1 0 \n",
"7 0.95 1 0 \n",
"8 0.95 0 1 \n",
"9 0.95 1 0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# preprocessing adults data set\n",
"\n",
"def _min_max_scale(frame, columns):\n",
"    \"\"\"Cast `columns` of `frame` to float64 and rescale each one to [0, 1] in place.\n",
"\n",
"    NOTE: a constant column (max == min) would become all-NaN.\n",
"    \"\"\"\n",
"    for c in columns:\n",
"        col = frame[c].astype('float64')\n",
"        frame[c] = (col - col.min()) / (col.max() - col.min())\n",
"\n",
"if data_set == 'adult':\n",
"    # Load dataset into a pandas DataFrame ('?' marks a missing value)\n",
"    adult_data = pd.read_csv('adult_data.csv', na_values='?')\n",
"    # Drop all records with missing values\n",
"    adult_data = adult_data.dropna().reset_index(drop=True)\n",
"\n",
"    # Drop fnlwgt (not interesting for ML) and education (presumably redundant\n",
"    # with educational-num, which is kept)\n",
"    adult_data = adult_data.drop(columns=['fnlwgt', 'education'])\n",
"\n",
"    # Merge the marital-status categories into 'Married' / 'Unmarried'.\n",
"    # Assigning the result back (rather than Series.replace(..., inplace=True)\n",
"    # on a column) keeps working under pandas Copy-on-Write.\n",
"    marital_map = {\n",
"        'Married-civ-spouse': 'Married',\n",
"        'Married-spouse-absent': 'Married',\n",
"        'Married-AF-spouse': 'Married',\n",
"        'Divorced': 'Unmarried',\n",
"        'Never-married': 'Unmarried',\n",
"        'Separated': 'Unmarried',\n",
"        'Widowed': 'Unmarried',\n",
"    }\n",
"    adult_data['marital-status'] = adult_data['marital-status'].replace(marital_map)\n",
"\n",
"    # One-hot encode the target into the last two columns (income_<=50K,\n",
"    # income_>50K). dtype='uint8' keeps them numeric on pandas >= 2, where\n",
"    # get_dummies would otherwise return bool columns.\n",
"    income = pd.get_dummies(adult_data['income'], prefix='income', dtype='uint8')\n",
"    adult_data = pd.concat([adult_data.drop(columns='income'), income], axis=1)\n",
"\n",
"    # Min-max normalise the integer (numeric) features\n",
"    _min_max_scale(adult_data, adult_data.select_dtypes(['int64']).columns)\n",
"\n",
"    # Encode the string features ('workclass', 'marital-status', 'occupation',\n",
"    # 'relationship', 'race', 'gender', 'native-country') as category codes,\n",
"    # then min-max normalise the codes too\n",
"    obj_columns = adult_data.select_dtypes(['object']).columns\n",
"    for c in obj_columns:\n",
"        adult_data[c] = adult_data[c].astype('category').cat.codes\n",
"    _min_max_scale(adult_data, obj_columns)\n",
"\n",
"    # DataFrame.info() prints its report itself and returns None, so call it\n",
"    # directly instead of display()-ing its return value\n",
"    adult_data.info()\n",
"    display(adult_data.head(10))\n",
"\n",
"    adult_data = adult_data.to_numpy()\n",
"\n",
"    # split the data into train and test datasets (features = all but the\n",
"    # last two columns, target = the two one-hot income columns)\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        adult_data[:, :-2], adult_data[:, -2:], test_size=0.07, random_state=rs)\n",
"    # the names of the features\n",
"    names = ['age', 'workclass', 'educational-num', 'marital-status', 'occupation',\n",
"             'relationship', 'race', 'gender', 'capital-gain', 'capital-loss',\n",
"             'hours-per-week', 'native-country']\n",
"    Features_number = len(X_train[0])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"if data_set == 'synthatic':\n",
"    # generate a synthetic classification data set\n",
"    X, y = make_classification(n_samples=1000000, n_features=10, n_redundant=3,\n",
"                               n_repeated=2,  # n_classes=3,\n",
"                               n_informative=5, n_clusters_per_class=4,\n",
"                               random_state=42)\n",
"    # one-hot encode the labels (columns y_0, y_1, ...); dtype='uint8' keeps the\n",
"    # targets numeric on pandas >= 2, where get_dummies defaults to bool\n",
"    y = pd.get_dummies(pd.Series(y, name='y'), prefix='y', dtype='uint8').to_numpy()\n",
"    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.07, random_state=rs)\n",
"    # the names of the features: X(0) ... X(n_features - 1)\n",
"    names = ['X(%d)' % i for i in range(X.shape[1])]\n",
"    Features_number = len(X_train[0])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"if data_set == 'activity':\n",
"    # Load dataset into a pandas DataFrame\n",
"    activity = pd.read_csv(\"activity_3_original.csv\", sep=',')\n",
"    # drop some columns, among them features that have no value in the\n",
"    # majority of the samples\n",
"    to_drop = ['subject', 'timestamp', 'heart_rate', 'activityID']\n",
"    activity = activity.drop(columns=to_drop)\n",
"    # prepare the truth: one-hot encode 'motion' into the last two columns\n",
"    # (motion_n, motion_y); dtype='uint8' keeps them numeric on pandas >= 2\n",
"    motion = pd.get_dummies(activity['motion'], prefix='motion', dtype='uint8')\n",
"    activity = pd.concat([activity.drop(columns='motion'), motion], axis=1)\n",
"    class_label = ['motion_n', 'motion_y']\n",
"    predictors = [a for a in activity.columns.values if a not in class_label]\n",
"\n",
"    # Impute missing predictor values with the column mean. Assigning the\n",
"    # result back (rather than fillna(..., inplace=True) on a single column)\n",
"    # keeps working under pandas Copy-on-Write.\n",
"    activity[predictors] = activity[predictors].fillna(activity[predictors].mean())\n",
"\n",
"    display(predictors)\n",
"    for p in predictors:\n",
"        # min-max normalise to [0, 1] and store as float32 (astype returns a\n",
"        # new Series, so its result has to be assigned to take effect)\n",
"        col = activity[p]\n",
"        activity[p] = ((col - col.min()) / (col.max() - col.min())).astype('float32')\n",
"    activity = activity.to_numpy()\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        activity[:, :-2], activity[:, -2:], test_size=0.07, random_state=rs)\n",
"    # the names of the features\n",
"    names = ['temp_hand', 'acceleration_16_x_hand',\n",
"             'acceleration_16_y_hand', 'acceleration_16_z_hand', 'acceleration_6_x_hand',\n",
"             'acceleration_6_y_hand', 'acceleration_6_z_hand', 'gyroscope_x_hand', 'gyroscope_y_hand',\n",
"             'gyroscope_z_hand', 'magnetometer_x_hand', 'magnetometer_y_hand', 'magnetometer_z_hand',\n",
"             'temp_chest', 'acceleration_16_x_chest', 'acceleration_16_y_chest', 'acceleration_16_z_chest', 'acceleration_6_x_chest',\n",
"             'acceleration_6_y_chest', 'acceleration_6_z_chest', 'gyroscope_x_chest', 'gyroscope_y_chest', 'gyroscope_z_chest',\n",
"             'magnetometer_x_chest', 'magnetometer_y_chest', 'magnetometer_z_chest', 'temp_ankle', 'acceleration_16_x_ankle',\n",
"             'acceleration_16_y_ankle', 'acceleration_16_z_ankle', 'acceleration_6_x_ankle', 'acceleration_6_y_ankle',\n",
"             'acceleration_6_z_ankle', 'gyroscope_x_ankle', 'gyroscope_y_ankle', 'gyroscope_z_ankle', 'magnetometer_x_ankle',\n",
"             'magnetometer_y_ankle', 'magnetometer_z_ankle']\n",
"    Features_number = len(X_train[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 42056 samples, validate on 3166 samples\n",
"Epoch 1/2\n",
"42056/42056 [==============================] - ETA: 2:15 - loss: 0.7931 - accuracy: 0.28 - ETA: 5s - loss: 0.5651 - accuracy: 0.6988 - ETA: 3s - loss: 0.5165 - accuracy: 0.72 - ETA: 2s - loss: 0.4896 - accuracy: 0.74 - ETA: 2s - loss: 0.4646 - accuracy: 0.76 - ETA: 2s - loss: 0.4489 - accuracy: 0.77 - ETA: 1s - loss: 0.4412 - accuracy: 0.77 - ETA: 1s - loss: 0.4316 - accuracy: 0.78 - ETA: 1s - loss: 0.4258 - accuracy: 0.78 - ETA: 1s - loss: 0.4224 - accuracy: 0.79 - ETA: 1s - loss: 0.4158 - accuracy: 0.79 - ETA: 1s - loss: 0.4118 - accuracy: 0.79 - ETA: 1s - loss: 0.4073 - accuracy: 0.80 - ETA: 1s - loss: 0.4052 - accuracy: 0.80 - ETA: 1s - loss: 0.4025 - accuracy: 0.80 - ETA: 1s - loss: 0.3992 - accuracy: 0.80 - ETA: 1s - loss: 0.3983 - accuracy: 0.80 - ETA: 0s - loss: 0.3964 - accuracy: 0.80 - ETA: 0s - loss: 0.3930 - accuracy: 0.80 - ETA: 0s - loss: 0.3893 - accuracy: 0.81 - ETA: 0s - loss: 0.3881 - accuracy: 0.81 - ETA: 0s - loss: 0.3873 - accuracy: 0.81 - ETA: 0s - loss: 0.3857 - accuracy: 0.81 - ETA: 0s - loss: 0.3831 - accuracy: 0.81 - ETA: 0s - loss: 0.3811 - accuracy: 0.81 - ETA: 0s - loss: 0.3793 - accuracy: 0.81 - ETA: 0s - loss: 0.3784 - accuracy: 0.81 - ETA: 0s - loss: 0.3764 - accuracy: 0.81 - ETA: 0s - loss: 0.3756 - accuracy: 0.81 - ETA: 0s - loss: 0.3740 - accuracy: 0.82 - ETA: 0s - loss: 0.3726 - accuracy: 0.82 - 2s 41us/step - loss: 0.3720 - accuracy: 0.8217 - val_loss: 0.3452 - val_accuracy: 0.8370\n",
"\n",
"Epoch 00001: val_loss improved from inf to 0.34516, saving model to test.h8\n",
"Epoch 2/2\n",
"42056/42056 [==============================] - ETA: 3s - loss: 0.5095 - accuracy: 0.84 - ETA: 1s - loss: 0.3444 - accuracy: 0.83 - ETA: 1s - loss: 0.3407 - accuracy: 0.84 - ETA: 1s - loss: 0.3383 - accuracy: 0.83 - ETA: 1s - loss: 0.3365 - accuracy: 0.84 - ETA: 1s - loss: 0.3374 - accuracy: 0.84 - ETA: 1s - loss: 0.3386 - accuracy: 0.84 - ETA: 1s - loss: 0.3357 - accuracy: 0.84 - ETA: 1s - loss: 0.3364 - accuracy: 0.84 - ETA: 1s - loss: 0.3353 - accuracy: 0.84 - ETA: 1s - loss: 0.3351 - accuracy: 0.84 - ETA: 1s - loss: 0.3368 - accuracy: 0.83 - ETA: 1s - loss: 0.3381 - accuracy: 0.83 - ETA: 1s - loss: 0.3396 - accuracy: 0.83 - ETA: 0s - loss: 0.3388 - accuracy: 0.83 - ETA: 0s - loss: 0.3384 - accuracy: 0.83 - ETA: 0s - loss: 0.3393 - accuracy: 0.83 - ETA: 0s - loss: 0.3390 - accuracy: 0.83 - ETA: 0s - loss: 0.3392 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3406 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3421 - accuracy: 0.83 - ETA: 0s - loss: 0.3417 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3420 - accuracy: 0.83 - ETA: 0s - loss: 0.3411 - accuracy: 0.83 - ETA: 0s - loss: 0.3413 - accuracy: 0.83 - 2s 41us/step - loss: 0.3412 - accuracy: 0.8375 - val_loss: 0.3417 - val_accuracy: 0.8370\n",
"\n",
"Epoch 00002: val_loss improved from 0.34516 to 0.34170, saving model to test.h8\n"
]
}
],
"source": [
"#begin federated: first train one model on the whole training split,\n",
"# validating on the test split\n",
"\n",
"# Early stopping on val_loss (min_delta 0.01, patience 50, baseline 2);\n",
"# the best weights seen are restored when it triggers.\n",
"earlystopping = EarlyStopping(\n",
"    monitor='val_loss',\n",
"    min_delta=0.01,\n",
"    patience=50,\n",
"    verbose=1,\n",
"    baseline=2,\n",
"    restore_best_weights=True,\n",
")\n",
"\n",
"# Save the model whenever val_loss reaches a new minimum.\n",
"# NOTE(review): '.h8' is an unusual extension - confirm the intended save format.\n",
"checkpoint = ModelCheckpoint(\n",
"    'test.h8',\n",
"    monitor='val_loss',\n",
"    mode='min',\n",
"    save_best_only=True,\n",
"    verbose=1,\n",
")\n",
"\n",
"# Fully connected classifier: 70 -> 50 -> 50 ReLU units, 2-way softmax output\n",
"model = Sequential([\n",
"    Dense(70, input_dim=Features_number, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(2, activation='softmax'),\n",
"])\n",
"model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"history = model.fit(\n",
"    X_train, y_train,\n",
"    epochs=2,\n",
"    validation_data=(X_test, y_test),\n",
"    callbacks=[checkpoint, earlystopping],\n",
"    shuffle=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#AUXILIARY METHODS FOR FEDERATED LEARNING\n",
"\n",
"# RETURN INDICES TO LAYERS WITH WEIGHTS AND BIASES\n",
"def trainable_layers(model):\n",
"    \"\"\"Return the positions in `model.layers` of every layer that has parameters.\n",
"\n",
"    A layer qualifies when its `get_weights()` list is non-empty.\n",
"    \"\"\"\n",
"    indices = []\n",
"    for idx, layer in enumerate(model.layers):\n",
"        has_params = len(layer.get_weights()) > 0\n",
"        if has_params:\n",
"            indices.append(idx)\n",
"    return indices\n",
"\n",
"# RETURN WEIGHTS AND BIASES OF A MODEL\n",
"def get_parameters(model):\n",
"    \"\"\"Return deep copies of the kernels and biases of every parameter layer.\n",
"\n",
"    Returns ``(weights, biases)``: two lists aligned with\n",
"    ``trainable_layers(model)``, holding element 0 and element 1 of each such\n",
"    layer's ``get_weights()``. Assumes every parameter layer stores exactly\n",
"    [kernel, bias] (as Dense layers with a bias do).\n",
"    \"\"\"\n",
"    weights = []\n",
"    biases = []\n",
"    for i in trainable_layers(model):\n",
"        # fetch once instead of calling get_weights() twice per layer\n",
"        layer_params = model.layers[i].get_weights()\n",
"        weights.append(copy.deepcopy(layer_params[0]))\n",
"        biases.append(copy.deepcopy(layer_params[1]))\n",
"    return weights, biases\n",
" \n",
"# SET WEIGHTS AND BIASES OF A MODEL\n",
"def set_parameters(model, weights, biases):\n",
"    \"\"\"Load `weights[k]` and `biases[k]` into the k-th parameter layer of `model`.\n",
"\n",
"    Both lists must follow the ``trainable_layers(model)`` order, i.e. the\n",
"    order returned by ``get_parameters``.\n",
"    \"\"\"\n",
"    layer_indices = trainable_layers(model)\n",
"    for k in range(len(layer_indices)):\n",
"        target = model.layers[layer_indices[k]]\n",
"        target.set_weights([weights[k], biases[k]])\n",
" \n",
"# DEPRECATED: RETURN THE LOSS GRADIENTS OF THE MODEL FOR THE GIVEN DATA\n",
"def get_gradients(model, inputs, outputs):\n",
"    \"\"\"Return the loss gradients for all trainable weights, split into kernels and biases.\n",
"\n",
"    Evaluates the compiled model's symbolic gradients on (`inputs`,\n",
"    `outputs`). Returns ``(w_grad, b_grad)``: the even- and odd-positioned\n",
"    gradients, which assumes the trainable weights alternate kernel, bias.\n",
"\n",
"    NOTE(review): relies on private Keras attributes (``_feed_inputs``,\n",
"    ``_standardize_user_data``, ...) that may not exist in other Keras versions.\n",
"    \"\"\"\n",
"    grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)\n",
"    feeds = model._feed_inputs + model._feed_targets + model._feed_sample_weights\n",
"    grad_fn = K.function(feeds, grads)\n",
"    x, y, sample_weight = model._standardize_user_data(inputs, outputs)\n",
"    grad_values = list(grad_fn(x + y + sample_weight))\n",
"\n",
"    return grad_values[0::2], grad_values[1::2]\n",
"\n",
"# LOCAL TRAINING: RETURN HOW FAR THE WEIGHTS AND BIASES MOVED DURING AN UPDATE\n",
"def get_updates(model, inputs, outputs, batch_size, epochs):\n",
"    \"\"\"Fit `model` on (`inputs`, `outputs`) and return the parameter change.\n",
"\n",
"    Returns ``(weight_updates, bias_updates)``, per-layer lists of\n",
"    ``old - new`` arrays in the ``get_parameters`` order. The optimizer's\n",
"    learning rate is already applied, so these are not raw gradients; for a\n",
"    vanilla SGD step the gradients would be ``update / learning_rate``.\n",
"    `model` is left holding the trained parameters.\n",
"    \"\"\"\n",
"    w_before, b_before = get_parameters(model)\n",
"    model.fit(inputs, outputs, batch_size=batch_size, epochs=epochs, verbose=0)\n",
"    w_after, b_after = get_parameters(model)\n",
"\n",
"    def _diff(before, after):\n",
"        return [prev - cur for prev, cur in zip(before, after)]\n",
"\n",
"    return _diff(w_before, w_after), _diff(b_before, b_after)\n",
"\n",
"# APPLY A (SCALED) UPDATE TO THE MODEL'S WEIGHTS AND BIASES\n",
"def apply_updates(model, eta, w_new, b_new):\n",
"    \"\"\"Set every parameter of `model` to ``theta - eta * delta``.\n",
"\n",
"    `w_new` / `b_new` are per-layer update arrays in the ``get_parameters``\n",
"    order (for example the ``old - new`` differences from ``get_updates``);\n",
"    `eta` scales them.\n",
"    \"\"\"\n",
"    def _step(current, deltas):\n",
"        return [theta - eta * delta for theta, delta in zip(current, deltas)]\n",
"\n",
"    w, b = get_parameters(model)\n",
"    set_parameters(model, _step(w, w_new), _step(b, b_new))\n",
" \n",
"# FEDERATED AGGREGATION FUNCTION\n",
"def aggregate(n_layers, n_peers, f, w_updates, b_updates):\n",
"    \"\"\"Reduce the peers' updates layer by layer with `f`.\n",
"\n",
"    ``w_updates[j][i]`` / ``b_updates[j][i]`` is peer j's update for layer i.\n",
"    For each of the first `n_layers` layers, `f` is called as\n",
"    ``f([update of peer 0, ..., peer n_peers - 1], axis=0)`` (e.g. ``np.mean``).\n",
"    Returns ``(agg_w, agg_b)``, one aggregated array per layer.\n",
"    \"\"\"\n",
"    def _reduce(updates):\n",
"        return [f([updates[peer][layer] for peer in range(n_peers)], axis=0)\n",
"                for layer in range(n_layers)]\n",
"\n",
"    return _reduce(w_updates), _reduce(b_updates)\n",
"\n",
"# SOLVE NANS\n",
"def nans_to_zero(W, B):\n",
"    \"\"\"Return new lists of the arrays in `W` and `B` with NaN, +inf and -inf set to 0.0.\"\"\"\n",
"    def _zeroed(arrays):\n",
"        return [np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0) for a in arrays]\n",
"\n",
"    return _zeroed(W), _zeroed(B)\n",
"\n",
"def build_forest(X, y, n_estimators=1000, max_depth=7, random_state=0, verbose=1):\n",
"    \"\"\"Fit and return a RandomForestClassifier on (`X`, `y`).\n",
"\n",
"    Defaults: 1000 trees, max_depth=7, random_state=0 (reproducible fits) and\n",
"    verbose=1; override any of them to tune the forest.\n",
"    \"\"\"\n",
"    clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth,\n",
"                                 random_state=random_state, verbose=verbose)\n",
"    clf.fit(X, y)\n",
"    return clf\n",
"\n",
"# COMPUTE EUCLIDEAN DISTANCE OF WEIGHTS\n",
"def dist_weights(w_a, w_b):\n",
"    \"\"\"Euclidean distance between two parameter lists, each flattened into one vector.\"\"\"\n",
"    return euclidean(flatten_weights(w_a), flatten_weights(w_b))\n",
"\n",
"# TRANSFORM ALL WEIGHT TENSORS TO 1D ARRAY\n",
"def flatten_weights(w_in):\n",
"    \"\"\"Concatenate every array in `w_in`, each flattened (C order), into one 1-D array.\n",
"\n",
"    Uses a single np.concatenate, so the cost is linear in the total number\n",
"    of elements. An empty `w_in` gives an empty array.\n",
"    \"\"\"\n",
"    if len(w_in) == 0:\n",
"        return np.empty(0)\n",
"    return np.concatenate([np.reshape(w, -1) for w in w_in])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# scan the forest for trees that match the wrong predictions of the black-box\n",
"def scan_wrong(forest_predictions, FL_predict1, forest, y_test_local, X_test_local,\n",
"               return_wrong_importance=False):\n",
"    \"\"\"Feature importances of the forest, optionally focused on black-box errors.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    forest_predictions : per-sample forest outputs (not used by the computation).\n",
"    FL_predict1 : black-box (federated model) outputs, one row per sample;\n",
"        the predicted class is the argmax of each row.\n",
"    forest : fitted forest exposing ``estimators_`` and ``feature_importances_``;\n",
"        each tree's ``predict`` must return one row per sample (e.g. when\n",
"        trained on one-hot targets), whose argmax is the tree's class.\n",
"    y_test_local : one-hot ground truth for the same samples.\n",
"    X_test_local : the samples themselves, one row each.\n",
"    return_wrong_importance : if False (default) return\n",
"        ``forest.feature_importances_`` directly, without scanning. If True,\n",
"        for every sample the black-box got wrong, average the\n",
"        ``feature_importances_`` of the trees that make the same (wrong)\n",
"        prediction, then return the mean of those averages over the samples\n",
"        where at least one tree agreed (zeros if there are none).\n",
"    \"\"\"\n",
"    if not return_wrong_importance:\n",
"        # the default result does not depend on the scan, so skip its\n",
"        # n_wrong x n_trees tree predictions entirely\n",
"        return forest.feature_importances_\n",
"\n",
"    fl_labels = np.argmax(FL_predict1, axis=1)\n",
"    y_labels = np.argmax(y_test_local, axis=1)\n",
"    wrong = np.flatnonzero(fl_labels != y_labels)\n",
"    no_result = np.zeros_like(forest.feature_importances_, dtype=float)\n",
"    if wrong.size == 0:\n",
"        return no_result\n",
"\n",
"    X_wrong = np.asarray(X_test_local)[wrong]\n",
"    fl_wrong = fl_labels[wrong]\n",
"    trees = forest.estimators_\n",
"    # agree[t, s]: tree t predicts the black-box's (wrong) class for wrong sample s\n",
"    agree = np.array([np.argmax(tree.predict(X_wrong), axis=1) == fl_wrong\n",
"                      for tree in trees], dtype=float)\n",
"    importances = np.array([tree.feature_importances_ for tree in trees])\n",
"    n_agree = agree.sum(axis=0)\n",
"    seen = n_agree > 0\n",
"    if not seen.any():\n",
"        return no_result\n",
"    # per-sample mean importance over the agreeing trees, then mean over samples\n",
"    per_sample = (agree.T @ importances)[seen] / n_agree[seen, None]\n",
"    return per_sample.mean(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# indices of the trained model's layers that carry weights and biases\n",
"trainable_layers(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([array([[ 1.39432400e-01, 8.84631574e-02, -4.47415888e-01,\n",
" 1.23670131e-01, -2.65049934e-01, 2.56673127e-01,\n",
" 2.82177985e-01, -3.88451487e-01, -8.48813355e-02,\n",
" -4.55360711e-01, -2.55180508e-01, -1.34169891e-01,\n",
" -4.19932574e-01, 9.50885192e-02, -4.10533138e-02,\n",
" 1.23161055e-01, -3.34913731e-01, -3.29331495e-02,\n",
" -2.09537312e-01, 2.89370805e-01, -2.42182449e-01,\n",
" 9.41318497e-02, -8.54814351e-02, -2.53278345e-01,\n",
" 7.38841221e-02, 9.76254940e-02, 9.64644551e-03,\n",
" 4.62163612e-02, 1.47847623e-01, 3.28071006e-02,\n",
" -2.16738522e-01, -5.52587435e-02, -1.01704948e-01,\n",
" 2.31297538e-01, -3.01694840e-01, 2.23755836e-02,\n",
" -2.37541839e-01, -8.33741352e-02, -3.33046556e-01,\n",
" -3.82800475e-02, -2.60576427e-01, 1.35413051e-01,\n",
" 5.84374070e-02, -1.67372033e-01, 7.50956163e-02,\n",
" -2.44477212e-01, 4.34608996e-01, 1.95100769e-01,\n",
" -1.71157598e-01, 2.94538945e-01, -2.78368771e-01,\n",
" 3.23733628e-01, 7.93107301e-02, 2.28328109e-01,\n",
" 6.06352724e-02, -7.03767091e-02, 1.33410409e-01,\n",
" 1.21751621e-01, 1.97286800e-01, -9.21699479e-02,\n",
" -3.15490931e-01, 2.30563477e-01, -3.28507647e-02,\n",
" 8.77456143e-02, 5.48780151e-02, -4.60406430e-02,\n",
" -1.89183086e-01, 3.93763036e-02, 2.96199113e-01,\n",
" -2.79987492e-02],\n",
" [ 3.03673856e-02, 1.97539851e-02, -1.50838614e-01,\n",
" 1.14162855e-01, 1.80196881e-01, -1.22831225e-01,\n",
" 7.67074972e-02, -1.13835640e-01, -1.38265222e-01,\n",
" -6.62374571e-02, 1.81205988e-01, -2.81262010e-01,\n",
" 1.72400191e-01, 2.07341984e-01, -1.34065270e-01,\n",
" 6.87680393e-02, -6.93561733e-02, -2.21116617e-01,\n",
" 1.04925461e-01, 1.02081522e-02, 1.51008025e-01,\n",
" -2.92544812e-02, -1.05958931e-01, 1.61262244e-01,\n",
" 1.58383980e-01, -1.24027103e-01, -1.80273309e-01,\n",
" 2.02690706e-01, 1.30619720e-01, -1.44045368e-01,\n",
" 5.87314926e-02, -6.84582517e-02, 5.60571887e-02,\n",
" -1.27603471e-01, 2.20635161e-01, 1.71862170e-01,\n",
" 1.77298188e-02, 1.31710157e-01, -2.06363559e-01,\n",
" 1.41939849e-01, 4.67592143e-02, -2.25164890e-01,\n",
" 2.84170844e-02, -1.87025517e-01, 2.21437346e-02,\n",
" 2.89680868e-01, 2.44593516e-01, 5.39705567e-02,\n",
" 1.68798208e-01, -9.17015448e-02, -9.46003050e-02,\n",
" -8.50451589e-02, -9.65483636e-02, 2.15933964e-01,\n",
" 3.86347598e-03, -2.29437221e-02, 8.44280720e-02,\n",
" 1.96231887e-01, 3.78342345e-02, 1.12372516e-02,\n",
" 7.45132491e-02, -1.45243943e-01, 1.38520822e-01,\n",
" 1.27623096e-01, 9.93933976e-02, 7.73796961e-02,\n",
" 1.07909396e-01, 5.35671674e-02, -2.25077912e-01,\n",
" 1.48774251e-01],\n",
" [-2.48966157e-01, -1.18819617e-01, 3.78526822e-02,\n",
" -4.11971584e-02, 5.32225370e-02, 2.79902488e-01,\n",
" 3.43969136e-01, -5.78653142e-02, 1.67140678e-01,\n",
" -2.94734612e-02, 1.13698818e-01, -2.92426739e-02,\n",
" -1.79812416e-01, 2.88941506e-02, -1.41450733e-01,\n",
" -7.92392809e-03, -1.35528877e-01, -2.56182522e-01,\n",
" -2.33598545e-01, -5.47329225e-02, 2.58110791e-01,\n",
" -2.45282829e-01, 4.75647040e-02, 4.78960238e-02,\n",
" -6.56322390e-02, -6.67297915e-02, 1.69852525e-01,\n",
" -1.50414899e-01, 2.58721203e-01, 1.14194579e-01,\n",
" 2.65164256e-01, 8.89386758e-02, 2.67333359e-01,\n",
" -3.09747636e-01, -1.52420253e-02, 2.57288337e-01,\n",
" 1.46575630e-01, 8.43582675e-02, 1.89198285e-01,\n",
" -5.13301976e-02, -1.45431489e-01, 1.83323875e-01,\n",
" 2.22104147e-01, -7.55850300e-02, 1.44288674e-01,\n",
" -1.75847083e-01, -1.43846169e-01, 1.33877620e-01,\n",
" 1.63822114e-01, -1.28378317e-01, -2.10838597e-02,\n",
" -2.69852519e-01, 1.04066990e-01, 2.06833377e-01,\n",
" -1.28662705e-01, 1.49911791e-01, -2.75938064e-01,\n",
" -3.31552997e-02, 2.19017982e-01, 6.46202068e-04,\n",
" 1.66913256e-01, -1.72089741e-01, 9.96593982e-02,\n",
" -2.43812397e-01, -8.03031027e-02, -1.92508698e-01,\n",
" -3.14832121e-01, -9.16534588e-02, -3.15453112e-01,\n",
" 1.48415402e-01],\n",
" [-2.35771656e-01, 3.27018127e-02, 1.60873935e-01,\n",
" -1.28616795e-01, 3.11803758e-01, -2.35472228e-02,\n",
" -1.39719948e-01, 1.74694061e-02, 7.51914829e-02,\n",
" 2.35624880e-01, 7.33765140e-02, 2.16503426e-01,\n",
" 4.06566672e-02, -2.05656707e-01, 1.96258724e-01,\n",
" 5.99774197e-02, -1.27538797e-02, -6.30170330e-02,\n",
" -1.16274104e-01, -1.43104732e-01, -1.37973130e-01,\n",
" -1.91767380e-01, 3.22461128e-01, 2.99887396e-02,\n",
" 2.64688015e-01, -2.45580390e-01, -2.41390377e-01,\n",
" -1.29994661e-01, -1.80605844e-01, -2.61187732e-01,\n",
" 1.44567609e-01, 1.88110307e-01, 1.73101038e-01,\n",
" 2.86840070e-02, -1.33754045e-01, -5.33887371e-02,\n",
" -1.13288000e-01, -8.15718770e-02, 2.53453523e-01,\n",
" -1.54690027e-01, -1.32443011e-02, -6.94180205e-02,\n",
" -1.20536266e-02, -2.19712891e-02, -2.30549023e-01,\n",
" 2.46970072e-01, -1.82330459e-02, -1.24268174e-01,\n",
" 2.66243219e-01, 1.11885495e-01, 8.33856687e-02,\n",
" -1.06503241e-01, -2.80220248e-02, -1.17930442e-01,\n",
" 2.08708122e-01, 7.04001710e-02, -1.37973502e-02,\n",
" 1.89776018e-01, -7.30874389e-02, -2.11521506e-01,\n",
" 1.42071024e-01, 2.42409576e-02, 8.69186819e-02,\n",
" 3.34844626e-02, -2.07044452e-01, -1.04645088e-01,\n",
" 1.51515082e-01, -1.95780490e-02, 2.13911623e-01,\n",
" 9.59823653e-02],\n",
" [-2.26251304e-01, -4.98282760e-02, 8.57945010e-02,\n",
" 1.85095415e-01, 1.94030240e-01, 1.70300901e-01,\n",
" -1.48310944e-01, -1.68697998e-01, 1.38381734e-01,\n",
" -8.20567235e-02, 1.35808028e-02, -1.75055087e-01,\n",
" 2.08388101e-02, -2.22936451e-01, -7.68952891e-02,\n",
" -4.24526669e-02, 4.03720774e-02, 2.34893888e-01,\n",
" -1.57926619e-01, -2.40865514e-01, 1.67401552e-01,\n",
" 2.16235057e-01, -1.50564939e-01, 1.77459866e-01,\n",
" -1.02011845e-01, 9.56041086e-03, -1.36439502e-01,\n",
" 1.67499810e-01, 1.46594793e-01, -2.37665162e-03,\n",
" 2.35330492e-01, -4.87338640e-02, -8.25209543e-02,\n",
" -7.34776333e-02, 2.11637601e-01, -8.63815099e-02,\n",
" -2.52601802e-01, -1.03249528e-01, 1.14807218e-01,\n",
" 1.93410560e-01, -7.48374164e-02, 4.09806073e-02,\n",
" -1.25015989e-01, 1.75860271e-01, 1.65006757e-01,\n",
" 1.63865000e-01, 1.56919926e-01, -2.22888529e-01,\n",
" -3.29164751e-02, 4.06037048e-02, 2.24684268e-01,\n",
" 1.01046182e-01, -1.53632820e-01, -1.65310353e-01,\n",
" 4.86176573e-02, -2.46649399e-01, -2.84075760e-03,\n",
" 1.55264661e-01, 4.27330621e-02, -2.05510065e-01,\n",
" 1.62713528e-01, -3.14808562e-02, 1.86110288e-01,\n",
" 6.84845075e-02, 4.47224490e-02, -3.40451181e-01,\n",
" 1.40326787e-02, 2.19547436e-01, 7.52496868e-02,\n",
" 1.09770238e-01],\n",
" [-2.14519277e-01, -1.97733060e-01, -1.04191333e-01,\n",
" -1.52826672e-02, 1.04496861e-03, -6.56969398e-02,\n",
" -7.04714730e-02, -1.19291015e-01, 1.01761602e-01,\n",
" 7.52121955e-02, -2.15532720e-01, -1.47176266e-01,\n",
" 1.51603609e-01, -1.83050726e-02, -3.25457342e-02,\n",
" -5.11338934e-02, 1.16198196e-03, -2.66087204e-01,\n",
" 7.53995031e-02, 7.98415840e-02, 4.19246480e-02,\n",
" -7.96627849e-02, 1.22839414e-01, 1.80793643e-01,\n",
" -2.73334742e-01, 5.54925241e-02, 1.19968027e-01,\n",
" 1.63323641e-01, 1.11940101e-01, -1.46585837e-01,\n",
" 1.94005132e-01, 1.88561931e-01, -5.62924668e-02,\n",
" -4.18225750e-02, -1.56423241e-01, -2.25715101e-01,\n",
" -4.82656956e-02, 2.14031748e-02, 2.10182130e-01,\n",
" -3.18871409e-01, -7.38589093e-02, -2.32924759e-01,\n",
" 8.74556080e-02, -1.10086516e-01, 1.84157446e-01,\n",
" -1.46957889e-01, -1.06122330e-01, 2.88575172e-01,\n",
" 7.43130967e-02, 1.63028061e-01, 2.40940854e-01,\n",
" 8.84263813e-02, 1.86871052e-01, -1.03018314e-01,\n",
" -2.51245052e-02, -2.32590944e-01, 2.58567259e-02,\n",
" 1.24988005e-01, 4.27892543e-02, 6.42778203e-02,\n",
" 2.41022035e-01, -5.46587259e-02, -1.77857980e-01,\n",
" 3.70368622e-02, 2.42744144e-02, 1.84613451e-01,\n",
" 2.30415717e-01, -1.80632919e-01, -9.84579027e-02,\n",
" -4.87778150e-02],\n",
" [-2.97077070e-03, -9.92525965e-02, 9.59780440e-02,\n",
" -1.05714351e-01, -2.09908143e-01, 2.08500147e-01,\n",
" -9.31153223e-02, 2.99151987e-01, 4.34016176e-02,\n",
" -2.24611446e-01, 3.31769064e-02, 2.14490488e-01,\n",
" -2.24754527e-01, -1.74998924e-01, -4.15243544e-02,\n",
" -1.69698030e-01, 2.80564696e-01, 1.17882535e-01,\n",
" -9.80678648e-02, 3.15327570e-03, -2.08990425e-01,\n",
" 1.49431065e-01, -1.39306724e-01, 2.40346678e-02,\n",
" 2.40564555e-01, -5.09837978e-02, 2.17804000e-01,\n",
" 1.35088935e-01, 8.79955664e-02, -5.64928725e-02,\n",
" 4.61013429e-02, 6.54249862e-02, -8.42749923e-02,\n",
" 2.62729824e-01, -3.99206020e-02, -1.17483221e-01,\n",
" -1.40452668e-01, -1.06828704e-01, -1.74000204e-01,\n",
" -4.49550189e-02, 2.60878950e-01, 2.07423091e-01,\n",
" -9.15924609e-02, 1.91001654e-01, -1.47255644e-01,\n",
" 7.95471966e-02, -1.70050204e-01, 5.61165512e-02,\n",
" -1.48466706e-01, 1.08682081e-01, 2.04737335e-02,\n",
" -1.74528554e-01, -9.47896019e-02, 1.73530400e-01,\n",
" -1.12356387e-01, 9.92965326e-02, 1.26004890e-01,\n",
" -2.32813179e-01, 9.49711502e-02, -2.34253883e-01,\n",
" -2.76989549e-01, -7.66268969e-02, -3.41671556e-02,\n",
" -5.10511408e-03, -5.79159260e-02, -6.46380782e-02,\n",
" -5.50055876e-02, -3.11404735e-01, 2.45275497e-01,\n",
" -2.22187296e-01],\n",
" [ 4.54801060e-02, 2.56455511e-01, -1.82633027e-01,\n",
" -1.01602580e-02, -8.93032998e-02, 1.04237944e-01,\n",
" 5.84088564e-02, 1.54823989e-01, -1.07336426e-02,\n",
" 2.69688278e-01, 6.16033142e-03, -6.09616982e-03,\n",
" 8.98296311e-02, 1.78536490e-01, -1.43777172e-03,\n",
" -9.94328558e-02, -4.55807038e-02, 9.91010815e-02,\n",
" -4.42102812e-02, 3.77892517e-02, 1.33471981e-01,\n",
" 7.44501278e-02, 1.62690468e-02, 2.23104075e-01,\n",
" -2.61054993e-01, 3.15811366e-01, -2.96082795e-01,\n",
" 1.78025752e-01, -2.63285220e-01, -5.37474826e-02,\n",
" -9.58651751e-02, -2.15012103e-01, -4.33603339e-02,\n",
" -2.60652751e-01, -5.41594252e-02, 2.35952377e-01,\n",
" -3.74763012e-02, 1.91953376e-01, 1.17158510e-01,\n",
" 3.78518994e-03, -6.19563572e-02, 2.10780635e-01,\n",
" 1.62149847e-01, -1.30085796e-01, 1.28252106e-03,\n",
" 2.28483707e-01, -1.14689972e-02, -8.24389532e-02,\n",
" -1.77851245e-01, -1.37649611e-01, 1.65123567e-01,\n",
" 1.03654794e-01, 8.36220309e-02, 1.99557766e-02,\n",
" 6.00132421e-02, -3.04210056e-02, -2.81973660e-01,\n",
" -2.42123492e-02, -2.17434868e-01, -9.64278206e-02,\n",
" -1.85030416e-01, -2.62960136e-01, 5.34782112e-02,\n",
" 1.58508420e-01, 1.65380761e-01, -3.85079943e-02,\n",
" 2.55265355e-01, 5.09922206e-02, -1.47566527e-01,\n",
" 7.40251169e-02],\n",
" [ 1.29649222e-01, -3.14282179e-02, -6.06167972e-01,\n",
" -2.50955880e-01, -3.46874207e-01, 7.49993503e-01,\n",
" 7.28010595e-01, -1.06399655e+00, 1.06234324e+00,\n",
" -3.55233133e-01, -5.50140023e-01, 1.00409508e+00,\n",
" -5.45210958e-01, 1.93181217e-01, -7.01776028e-01,\n",
" -2.10634783e-01, -4.23527777e-01, 3.09440106e-01,\n",
" -1.91907719e-01, 2.85458267e-01, 7.82932997e-01,\n",
" -5.32808244e-01, -4.39185768e-01, -7.65542090e-01,\n",
" -3.82927716e-01, -1.15567505e+00, 1.67764112e-01,\n",
" -9.48192775e-01, 3.64812821e-01, 2.28667915e-01,\n",
" 6.75961256e-01, 8.27623010e-01, -6.38736844e-01,\n",
" -2.00036347e-01, -3.25849533e-01, 9.03906941e-01,\n",
" -2.68816352e-01, -6.27302647e-01, -3.23336124e-01,\n",
" 4.70992297e-01, -5.73931932e-01, 9.17997599e-01,\n",
" 7.42488205e-01, -2.06164107e-01, 2.04111740e-01,\n",
" -7.19973087e-01, -3.76782537e-01, 8.55549395e-01,\n",
" -8.38361323e-01, -9.57333803e-01, -5.20633638e-01,\n",
" -3.67659301e-01, 1.50605768e-01, -7.64182091e-01,\n",
" -3.19448918e-01, -5.01123592e-02, 1.64251193e-01,\n",
" -7.17021644e-01, 6.97100699e-01, -1.25111267e-01,\n",
" -4.96421248e-01, -5.05610764e-01, 1.01232016e+00,\n",
" -1.09313202e+00, 3.20109189e-01, -1.06782168e-01,\n",
" -9.03548539e-01, -2.81452984e-01, -2.17785537e-01,\n",
" 6.68265998e-01],\n",
" [-2.42571607e-01, 2.04211175e-01, -4.92268875e-02,\n",
" -1.63620815e-01, 4.04583551e-02, 3.45696330e-01,\n",
" 3.54173370e-02, 8.10830146e-02, -1.61551312e-03,\n",
" -1.48698622e-02, 1.94258001e-02, 1.15005746e-01,\n",
" -1.20848659e-02, 2.12298751e-01, 8.92769620e-02,\n",
" -7.64900148e-02, -3.41445431e-02, 7.51630887e-02,\n",
" -3.40494029e-02, 2.70350277e-01, 4.42853682e-02,\n",
" 5.13006859e-02, 2.81202555e-01, 1.35484681e-01,\n",
" 7.24686086e-02, -1.34075984e-01, 1.70696169e-01,\n",
" 8.00305977e-03, 2.56366223e-01, 1.33748680e-01,\n",
" 2.25041300e-01, -7.13687986e-02, -4.96987440e-02,\n",
" 3.10503058e-02, -2.25651234e-01, 3.94519985e-01,\n",
" 2.15304196e-01, 1.15869548e-02, 1.47072956e-01,\n",
" 3.18337977e-01, -6.86229253e-03, -5.09570874e-02,\n",
" 2.29824454e-01, 2.61031240e-02, 2.89728433e-01,\n",
" 3.48875783e-02, -5.50319031e-02, 1.21588549e-02,\n",
" -4.12969440e-02, 1.07327215e-01, 1.35437414e-01,\n",
" -2.93096341e-02, 3.36093381e-02, -1.90401971e-01,\n",
" -2.66215026e-01, 6.08073771e-02, 6.91038072e-02,\n",
" -1.98440487e-03, 7.31287152e-02, -2.77851731e-01,\n",
" -1.08341835e-01, -1.85085818e-01, 2.29901448e-01,\n",
" -2.96091676e-01, 2.23246500e-01, -1.44393981e-01,\n",
" -1.93921745e-01, -1.92566663e-01, 1.32529914e-01,\n",
" -1.94337085e-01],\n",
" [-2.53917843e-01, -2.51892120e-01, -1.32432416e-01,\n",
" 1.47464365e-01, -3.17318618e-01, 1.97301418e-01,\n",
" 2.69987226e-01, -4.56497446e-02, 2.30195507e-01,\n",
" 1.32218450e-02, -4.06064779e-01, 2.51328260e-01,\n",
" 5.33021428e-02, -2.66608417e-01, -1.79995075e-01,\n",
" 1.29986405e-01, 2.86205828e-01, 1.35580912e-01,\n",
" -2.18271971e-01, 1.57579169e-01, 2.17058808e-01,\n",
" -1.03528440e-01, -4.87327874e-02, -6.85375035e-02,\n",
" -1.29330382e-01, -1.23507090e-01, 4.80753556e-02,\n",
" -3.16315889e-01, 2.06642285e-01, -1.25930071e-01,\n",
" 4.74674441e-03, -2.28398144e-01, 1.07306920e-01,\n",
" 2.11515024e-01, -1.60666719e-01, 1.23706006e-01,\n",
" -2.24141285e-01, 5.61789013e-02, 1.76867880e-02,\n",
" -9.17073190e-02, 8.19897652e-02, -1.55695155e-02,\n",
" 3.17650735e-01, 2.38761097e-01, 2.45510742e-01,\n",
" 8.75920355e-02, 3.21321398e-01, -1.12799473e-01,\n",
" 2.10606474e-02, -1.81161851e-01, 2.40284592e-01,\n",
" -2.50088274e-01, -2.18976215e-01, 1.12234220e-01,\n",
" 4.77548651e-02, -4.73017395e-02, 1.37630356e-02,\n",
" 1.92280307e-01, 7.14965388e-02, -6.21563159e-02,\n",
" -7.16416389e-02, 1.23388998e-01, 1.82368487e-01,\n",
" 2.31735930e-01, -2.02105597e-01, 1.42061830e-01,\n",
" 1.23616353e-01, 1.56008020e-01, -2.19544828e-01,\n",
" 2.12301493e-01],\n",
" [-7.18890205e-02, -1.92233965e-01, 2.33305559e-01,\n",
" 6.87015578e-02, 8.51642191e-02, -2.19545767e-01,\n",
" 6.00749105e-02, 3.61590572e-02, -6.68269545e-02,\n",
" 1.48855716e-01, 2.30278343e-01, 2.16507941e-01,\n",
" 2.22660348e-01, -2.84734219e-01, 2.37847969e-01,\n",
" -1.55460656e-01, -2.26989180e-01, -6.12188876e-02,\n",
" 1.77810416e-01, 1.45450696e-01, 2.52608925e-01,\n",
" -1.36337921e-01, -1.94631949e-01, 1.51410148e-01,\n",
" 3.44162211e-02, 1.61046118e-01, -1.24759860e-01,\n",
" 1.83450043e-01, 1.75598450e-02, -2.05802217e-01,\n",
" -8.66022483e-02, 6.08737469e-02, -2.22572535e-01,\n",
" -1.20819479e-01, -1.68945014e-01, 2.10285246e-01,\n",
" -2.40360171e-01, -1.79741889e-01, -1.93881094e-01,\n",
" 1.32005673e-03, -8.93675536e-02, -1.65670961e-01,\n",
" -8.00130144e-02, -2.01122567e-01, 1.55159965e-01,\n",
" 4.84559573e-02, -1.92197278e-01, 1.46897465e-01,\n",
" -1.71061575e-01, 5.79360016e-02, -1.45457163e-01,\n",
" 1.78534076e-01, 1.95346713e-01, -9.44947526e-02,\n",
" -2.78981924e-01, -1.16451114e-01, 1.21675292e-02,\n",
" -1.05452980e-03, 2.97299847e-02, 1.15553983e-01,\n",
" -1.47618756e-01, 2.83984572e-01, -9.44054872e-02,\n",
" -6.82652295e-02, 1.54531911e-01, -9.11844522e-02,\n",
" 2.69836523e-02, -3.09856743e-01, 6.67436346e-02,\n",
" 2.40427703e-01]], dtype=float32),\n",
" array([[ 0.04073317, -0.16170031, 0.08170982, ..., 0.12143514,\n",
" -0.03804543, -0.1848121 ],\n",
" [-0.02006259, 0.04184515, 0.20358184, ..., 0.08938669,\n",
" 0.02554417, -0.0998741 ],\n",
" [-0.05848309, -0.13393435, 0.28651938, ..., -0.19336581,\n",
" 0.28697622, -0.18376462],\n",
" ...,\n",
" [-0.13138615, -0.10152157, 0.05253223, ..., 0.16827357,\n",
" 0.09525165, 0.17411834],\n",
" [-0.00976845, -0.10780089, 0.2228816 , ..., 0.1733975 ,\n",
" -0.10156322, 0.03318954],\n",
" [ 0.09590832, -0.01828083, 0.12743485, ..., 0.25016934,\n",
" 0.12800731, -0.10581163]], dtype=float32),\n",
" array([[-0.00384881, -0.12021059, 0.01248708, ..., 0.01682259,\n",
" -0.17754331, 0.02930963],\n",
" [-0.03520177, 0.0117013 , 0.03343487, ..., -0.16231427,\n",
" 0.1756002 , 0.00351096],\n",
" [-0.1752005 , 0.004585 , -0.11959553, ..., -0.17236647,\n",
" 0.28346488, 0.26809448],\n",
" ...,\n",
" [ 0.01488994, 0.00250473, -0.25695267, ..., -0.11059541,\n",
" 0.17581026, -0.23348542],\n",
" [ 0.21297403, 0.24602796, 0.06359419, ..., 0.205567 ,\n",
" 0.04510517, 0.11687386],\n",
" [-0.17597616, 0.07059528, 0.10327347, ..., -0.02315794,\n",
" 0.00959007, -0.01356981]], dtype=float32),\n",
" array([[-2.82930046e-01, 1.26908660e-01],\n",
" [ 2.37486243e-01, -3.81716669e-01],\n",
" [ 9.92978290e-02, -3.47963899e-01],\n",
" [ 3.02352726e-01, -3.74164760e-01],\n",
" [-2.05417976e-01, 2.52470911e-01],\n",
" [-4.55201864e-02, -2.02432677e-01],\n",
" [ 1.73006430e-01, -4.46816646e-02],\n",
" [-2.84130216e-01, -2.26977065e-01],\n",
" [-4.35910234e-03, 3.76744062e-01],\n",
" [ 1.45330116e-01, -3.25348943e-01],\n",
" [ 2.28147835e-01, -2.77784109e-01],\n",
" [-1.19501755e-01, 4.07545753e-02],\n",
" [ 1.01264335e-01, -2.43342578e-01],\n",
" [-1.60477936e-01, 5.24386704e-01],\n",
" [ 6.06849305e-02, 9.89513546e-02],\n",
" [-2.89398909e-01, 1.83537304e-01],\n",
" [ 1.01001307e-01, 2.95499355e-01],\n",
" [-2.97017217e-01, 3.22097719e-01],\n",
" [ 1.97861195e-01, -2.02269956e-01],\n",
" [-7.52512068e-02, 9.88621786e-02],\n",
" [-1.38137221e-01, -4.40452248e-01],\n",
" [-2.33402535e-01, 1.64692864e-01],\n",
" [ 1.27101064e-01, 1.98759794e-01],\n",
" [-3.01784992e-01, 2.12811917e-01],\n",
" [-1.96352318e-01, 1.54295802e-01],\n",
" [ 2.49975443e-01, -2.01082289e-01],\n",
" [-1.38984874e-01, 2.29037121e-01],\n",
" [ 1.06105595e-04, 2.89339125e-01],\n",
" [-3.00384670e-01, -4.83968072e-02],\n",
" [-1.86271910e-02, 3.05029899e-01],\n",
" [ 1.99009106e-03, 2.14236692e-01],\n",
" [-3.34532440e-01, -6.11541159e-02],\n",
" [-1.35282487e-01, 5.65957166e-02],\n",
" [-2.68078923e-01, 1.56603098e-01],\n",
" [-2.48180240e-01, 2.94318020e-01],\n",
" [-1.30787000e-01, -4.86690477e-02],\n",
" [ 2.72127450e-01, 1.19140044e-01],\n",
" [ 2.90248722e-01, -1.86683103e-01],\n",
" [ 9.85520706e-02, -2.18973175e-01],\n",
" [-3.67538421e-03, -1.40206725e-03],\n",
" [ 2.09687546e-01, 2.38504097e-01],\n",
" [-1.00464404e-01, 2.42502570e-01],\n",
" [ 1.55400500e-01, 1.01416796e-01],\n",
" [-3.02865952e-01, 7.42565766e-02],\n",
" [-1.75600335e-01, 3.34860444e-01],\n",
" [-1.38489887e-01, 2.21242890e-01],\n",
" [ 1.80740595e-01, 2.85507560e-01],\n",
" [ 2.81139612e-01, -1.64098963e-01],\n",
" [ 1.56524777e-01, -3.87664348e-01],\n",
" [ 4.04402643e-01, 3.33227254e-02]], dtype=float32)],\n",
" [array([-0.02854579, 0.05311754, 0.04959053, -0.03715162, 0.02330466,\n",
" -0.04269688, -0.01013005, 0.06874033, -0.01974296, 0.06377108,\n",
" 0.01915232, -0.02997592, 0.04385599, -0.03190788, 0.0822746 ,\n",
" -0.0659938 , 0.05614723, -0.05916027, 0.00469705, -0.07926938,\n",
" -0.04140961, 0.04683878, 0.05930374, 0.03659062, 0.01063251,\n",
" 0.08476436, -0.05785139, 0.06547377, -0.06716973, -0.03636409,\n",
" -0.07686226, -0.06849793, 0.05797694, 0.02435303, 0.03509652,\n",
" -0.05251152, 0. , 0.08996248, 0.03742762, -0.01523749,\n",
" 0.03550566, -0.03804714, -0.00484502, -0.00379997, -0.08697824,\n",
" 0.05449862, 0.0753291 , -0.04323955, 0.06112291, 0.07690092,\n",
" -0.0228012 , 0.0299318 , -0.00519692, -0.02068757, -0.00509608,\n",
" -0.05934853, 0.01074862, 0.04968895, -0.06797037, -0.03621136,\n",
" 0.04376265, 0.05112915, -0.01860299, 0.1174278 , -0.07478895,\n",
" -0.03162573, 0.03712742, -0.02407979, 0.01463216, -0.06149809],\n",
" dtype=float32),\n",
" array([-0.00573702, 0.02881279, 0.08348639, 0.05104946, 0.00677452,\n",
" 0.07581051, -0.00896564, 0.05683671, -0.04194446, -0.03258895,\n",
" 0.02584886, -0.01876919, -0.03003295, 0.06011844, -0.03459567,\n",
" -0.02480574, -0.02663354, -0.01426572, 0.04364663, 0.08130559,\n",
" 0.06495807, -0.04167927, 0.05315804, 0.06356885, -0.0253933 ,\n",
" 0.04731547, -0.05747719, 0.05308496, -0.05659075, 0.05597835,\n",
" -0.04872588, -0.04101903, 0.03206719, 0.05267171, -0.05795552,\n",
" 0.02027026, -0.05280019, 0.05442906, -0.01780812, 0.06429685,\n",
" -0.04762284, 0.06730605, -0.05093784, 0.05504497, 0.05357518,\n",
" -0.02114536, 0.07897805, 0.02339305, 0.08138859, -0.03255979],\n",
" dtype=float32),\n",
" array([-0.02713361, 0.09309322, 0.03805388, 0.03424142, -0.01113803,\n",
" 0.01667055, 0.07447569, -0.06668621, -0.03150655, 0.06387825,\n",
" 0.07785708, -0.02531729, 0.0723519 , 0.00347189, -0.0177357 ,\n",
" 0. , -0.03487411, -0.04851071, 0.06759044, -0.03533144,\n",
" 0.02401838, -0.03073125, -0.04450166, -0.02654641, -0.02827855,\n",
" 0. , -0.01567897, -0.0040624 , -0.02975524, -0.03281569,\n",
" -0.02176444, -0.03663797, -0.01836163, -0.01098274, -0.05059136,\n",
" -0.02361586, 0.06188901, 0.07352342, 0.05501923, -0.04410602,\n",
" -0.04001745, -0.05807052, 0.05206716, -0.01867674, -0.0089961 ,\n",
" -0.01759916, -0.01311922, -0.01491712, 0.05937983, 0.06336498],\n",
" dtype=float32),\n",
" array([ 0.05128358, -0.05128355], dtype=float32)])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_parameters(model)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([array([[ 1.49011612e-06, -1.16565228e-02, 7.42741823e-02,\n",
" 1.84727460e-02, -8.54203105e-03, -1.07147992e-02,\n",
" -1.21345818e-02, 7.30311871e-02, -9.66983289e-03,\n",
" 6.97897077e-02, 3.34345400e-02, 8.74996185e-05,\n",
" 8.57096016e-02, 1.60385072e-02, 1.19651444e-02,\n",
" 2.49370933e-03, 8.43369067e-02, 1.64110810e-02,\n",
" 2.05034018e-02, -1.07030272e-02, 1.41942352e-02,\n",
" 8.08978826e-03, 3.21836621e-02, 5.65186143e-02,\n",
" -4.39171195e-02, 4.30863351e-03, -1.15206484e-02,\n",
" 2.10180618e-02, -7.25415349e-03, 3.23727727e-06,\n",
" -3.62303257e-02, 1.32649988e-02, 6.42685592e-03,\n",
" 1.92415118e-02, 3.42182815e-02, 1.04672834e-02,\n",
" 0.00000000e+00, 2.78882086e-02, 4.19589877e-02,\n",
" -2.18840521e-02, 9.14855003e-02, 4.81577218e-03,\n",
" -5.83276898e-03, 1.10103935e-02, -1.27351061e-02,\n",
" 1.04498863e-02, -4.18264568e-02, 2.71596760e-02,\n",
" 2.05704421e-02, -3.03956270e-02, 3.23033035e-02,\n",
" -4.21546996e-02, 4.45771776e-02, -3.91504467e-02,\n",
" -8.15168023e-02, 2.26370320e-02, -4.53699529e-02,\n",
" -4.22071815e-02, -1.31393373e-02, 1.32917091e-02,\n",
" 3.45049202e-02, -2.07801014e-02, -3.54883261e-03,\n",
" 1.25047117e-02, -6.88902661e-03, -3.12848352e-02,\n",
" 1.91906095e-02, 4.19499911e-02, -6.82501495e-02,\n",
" -1.53118074e-02],\n",
" [ 1.26473606e-06, 3.19551006e-02, -9.95661318e-03,\n",
" 1.19727850e-02, -1.79602206e-03, 1.53238177e-02,\n",
" 3.48728746e-02, -6.26268983e-03, 1.16313845e-02,\n",
" -1.24692582e-02, 7.92261958e-03, 1.70102119e-02,\n",
" -6.51185215e-03, 1.54842287e-02, -1.05841383e-02,\n",
" 6.68349117e-03, -5.42246550e-03, 2.38585621e-02,\n",
" -4.80242074e-03, 4.91370745e-02, 1.68562233e-02,\n",
" 1.78126954e-02, 1.90734789e-02, 1.65916681e-02,\n",
" 1.61990523e-03, -8.57291371e-03, 2.71531343e-02,\n",
" -7.27659464e-03, 4.15879637e-02, 2.78651714e-06,\n",
" -5.89160994e-03, -5.41770346e-02, 2.21652985e-02,\n",
" 2.85971761e-02, 2.64423043e-02, 2.02812403e-02,\n",
" 0.00000000e+00, 4.61198390e-03, 1.03888065e-02,\n",
" -1.25160664e-02, 6.73976541e-03, 7.16526806e-03,\n",
" 2.04948001e-02, 3.80229056e-02, 5.23998961e-03,\n",
" 1.47194266e-02, -3.79353762e-04, 1.02847666e-02,\n",
" 2.61461437e-02, -2.40194798e-02, 6.51396215e-02,\n",
" -3.63333039e-02, -5.51320203e-02, 6.67399764e-02,\n",
" 2.30744481e-04, -1.98918562e-02, -1.48451179e-02,\n",
" 2.16108412e-02, 2.36949362e-02, 1.34854689e-02,\n",
" 2.18952857e-02, -3.03417742e-02, 1.44592375e-02,\n",
" -5.22305071e-03, 2.16116756e-02, -1.71370506e-02,\n",
" -1.07840151e-02, 2.18297653e-02, -4.98308837e-02,\n",
" 6.17059022e-02],\n",
" [ 7.59959221e-07, 6.49914891e-03, 6.02702424e-03,\n",
" 4.44378480e-02, -3.70854139e-03, -4.58493829e-03,\n",
" -6.07848167e-03, 1.07879266e-02, -7.88053870e-03,\n",
" 5.81801683e-03, -5.74702024e-03, -1.39639974e-02,\n",
" -6.83215261e-03, 4.36457917e-02, 3.14892232e-02,\n",
" 7.88080785e-03, -8.65940750e-03, 2.13787258e-02,\n",
" -4.84359264e-03, 3.76113243e-02, -7.68777728e-03,\n",
" 3.44041884e-02, 2.17271540e-02, 3.49786282e-02,\n",
" -9.57245380e-03, 1.17718354e-02, 1.76328421e-03,\n",
" 2.47428566e-02, -1.89929008e-02, 9.83476639e-07,\n",
" -4.87717986e-03, 1.81090236e-02, -1.68472528e-03,\n",
" 1.75802708e-02, 1.03006195e-02, 2.29674578e-03,\n",
" 0.00000000e+00, 1.71143487e-02, -5.93899190e-03,\n",
" -8.52180272e-03, 2.21078098e-03, -6.65974617e-03,\n",
" -5.27688861e-03, 1.78252906e-03, 1.92168504e-02,\n",
" 3.20638865e-02, 3.12381685e-02, -9.20593739e-04,\n",