%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import entropy
import warnings
from causalml.inference.meta import LRSRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.nn import DragonNet
from causalml.match import NearestNeighborMatch, MatchOptimizer, create_table_one
from causalml.propensity import ElasticNetPropensityModel
from causalml.dataset.regression import *
from causalml.metrics import *
import os, sys
%matplotlib inline
warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')
sns.set_palette('Paired')
plt.rcParams['figure.figsize'] = (12,8)
Hill introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). The dataset is based on a randomized experiment investigating the effect of home visits by specialists on future cognitive test scores, and it contains 747 observations (rows). The IHDP simulation is considered the de facto standard benchmark for neural-network-based treatment effect estimation methods.
The original paper uses 1000 realizations from the NPCI package, but for illustration purposes we use a single realization (one dataset) below.
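Purely as a hedged, commented-out sketch (not run here; the data/ihdp_npci_<i>.csv file-name pattern and the number of locally available realizations are assumptions), metrics could be averaged over several realizations in the same way, e.g. with a naive difference-in-means per file:
# Hypothetical sketch: average a naive difference-in-means estimate over several realizations.
# ates = []
# for i in range(1, 11):
#     df_i = pd.read_csv(f'data/ihdp_npci_{i}.csv', header=None)
#     treated = df_i[df_i[0] == 1][1]   # column 0: treatment, column 1: factual outcome
#     control = df_i[df_i[0] == 0][1]
#     ates.append(treated.mean() - control.mean())
# np.mean(ates)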
df = pd.read_csv('data/ihdp_npci_3.csv', header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + [f'x{i}' for i in range(1,26)]
df.columns = cols
df.shape
(747, 30)
df.head()
| | treatment | y_factual | y_cfactual | mu0 | mu1 | x1 | x2 | x3 | x4 | x5 | ... | x16 | x17 | x18 | x19 | x20 | x21 | x22 | x23 | x24 | x25 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 5.931652 | 3.500591 | 2.253801 | 7.136441 | -0.528603 | -0.343455 | 1.128554 | 0.161703 | -0.316603 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 2.175966 | 5.952101 | 1.257592 | 6.553022 | -1.736945 | -1.802002 | 0.383828 | 2.244320 | -0.629189 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 2.180294 | 7.175734 | 2.384100 | 7.192645 | -0.807451 | -0.202946 | -0.360898 | -0.879606 | 0.808706 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 3.587662 | 7.787537 | 4.009365 | 7.712456 | 0.390083 | 0.596582 | -1.850350 | -0.879606 | -0.004017 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 2.372618 | 5.461871 | 2.481631 | 7.232739 | -1.045229 | -0.602710 | 0.011465 | 0.161703 | 0.683672 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 30 columns
df['treatment'].value_counts(normalize=True)
0    0.813922
1    0.186078
Name: treatment, dtype: float64
X = df.loc[:,'x1':]
treatment = df['treatment']
y = df['y_factual']
tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment'] == 1
               else d['y_cfactual'] - d['y_factual'],
               axis=1)
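Because this semi-synthetic dataset exposes both the factual and the counterfactual outcome, the true individual treatment effect can equivalently be computed in vectorized form (a small sketch matching the apply above):
# Vectorized equivalent of the apply() above: treated-minus-control outcome per row.
tau_check = np.where(df['treatment'] == 1,
                     df['y_factual'] - df['y_cfactual'],
                     df['y_cfactual'] - df['y_factual'])
assert np.allclose(tau_check, tau.values)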
# p_model = LogisticRegressionCV(penalty='elasticnet', solver='saga', l1_ratios=np.linspace(0,1,5),
# cv=StratifiedKFold(n_splits=4, shuffle=True))
# p_model.fit(X, treatment)
# p = p_model.predict_proba(X)[:, 1]
p_model = ElasticNetPropensityModel()
p = p_model.fit_predict(X, treatment)
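As a quick diagnostic (a sketch, not part of the original notebook), the overlap of the estimated propensity scores between treated and control units can be inspected before fitting the learners:
# Sketch: visual overlap check of the estimated propensity scores by treatment group.
plt.hist(p[treatment.values == 1], bins=30, alpha=0.5, density=True, label='treated')
plt.hist(p[treatment.values == 0], bins=30, alpha=0.5, density=True, label='control')
plt.xlabel('estimated propensity score')
plt.legend()
plt.show()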
# S-learner: one model with the treatment indicator as an extra feature
s_learner = BaseSRegressor(LGBMRegressor())
s_ate = s_learner.estimate_ate(X, treatment, y)[0]
s_ite = s_learner.fit_predict(X, treatment, y)
# T-learner: separate outcome models for treated and control units
t_learner = BaseTRegressor(LGBMRegressor())
t_ate = t_learner.estimate_ate(X, treatment, y)[0][0]
t_ite = t_learner.fit_predict(X, treatment, y)
# X-learner: imputes treatment effects and weights them using the propensity score p
x_learner = BaseXRegressor(LGBMRegressor())
x_ate = x_learner.estimate_ate(X, treatment, y, p)[0][0]
x_ite = x_learner.fit_predict(X, treatment, y, p)
# R-learner: residualizes the outcome and treatment using the propensity score p
r_learner = BaseRRegressor(LGBMRegressor())
r_ate = r_learner.estimate_ate(X, treatment, y, p)[0][0]
r_ite = r_learner.fit_predict(X, treatment, y, p)
# DragonNet: a neural network that jointly models the propensity score and both
# potential outcomes, here with targeted regularization enabled
dragon = DragonNet(neurons_per_layer=200, targeted_reg=True)
dragon_ite = dragon.fit_predict(X, treatment, y, return_components=False)
dragon_ate = dragon_ite.mean()
Train on 597 samples, validate on 150 samples
Epoch 1/30
597/597 [==============================] - 2s 3ms/step - loss: 1270.6818 - regression_loss: 614.8568 - binary_classification_loss: 37.4892 - treatment_accuracy: 0.8509 - track_epsilon: 0.0425 - val_loss: 239.1906 - val_regression_loss: 97.6509 - val_binary_classification_loss: 40.0911 - val_treatment_accuracy: 0.6600 - val_track_epsilon: 0.0418
Epoch 2/30
597/597 [==============================] - 0s 83us/step - loss: 296.8994 - regression_loss: 127.7073 - binary_classification_loss: 29.0478 - treatment_accuracy: 0.8526 - track_epsilon: 0.0421 - val_loss: 239.8115 - val_regression_loss: 94.8186 - val_binary_classification_loss: 43.9986 - val_treatment_accuracy: 0.6600 - val_track_epsilon: 0.0405
Epoch 3/30
597/597 [==============================] - 0s 97us/step - loss: 255.4919 - regression_loss: 113.5166 - binary_classification_loss: 27.9721 - treatment_accuracy: 0.8526 - track_epsilon: 0.0396 - val_loss: 259.4397 - val_regression_loss: 108.6940 - val_binary_classification_loss: 43.2644 - val_treatment_accuracy: 0.6600 - val_track_epsilon: 0.0396
Train on 597 samples, validate on 150 samples
Epoch 1/300
597/597 [==============================] - 1s 2ms/step - loss: 211.4195 - regression_loss: 90.2754 - binary_classification_loss: 27.2481 - treatment_accuracy: 0.8526 - track_epsilon: 0.0427 - val_loss: 219.0321 - val_regression_loss: 83.5967 - val_binary_classification_loss: 45.4129 - val_treatment_accuracy: 0.6600 - val_track_epsilon: 0.0484
...
Epoch 90/300
597/597 [==============================] - 0s 70us/step - loss: 143.2687 - regression_loss: 56.5889 - binary_classification_loss: 25.3412 - treatment_accuracy: 0.8526 - track_epsilon: 0.0021 - val_loss: 182.4297 - val_regression_loss: 66.4488 - val_binary_classification_loss: 44.7961 - val_treatment_accuracy: 0.6600 - val_track_epsilon: 0.0021
df_preds = pd.DataFrame([s_ite.ravel(),
                         t_ite.ravel(),
                         x_ite.ravel(),
                         r_ite.ravel(),
                         dragon_ite.ravel(),
                         tau.ravel(),
                         treatment.ravel(),
                         y.ravel()],
                        index=['S', 'T', 'X', 'R', 'dragonnet', 'tau', 'w', 'y']).T
df_cumgain = get_cumgain(df_preds)
df_result = pd.DataFrame([s_ate, t_ate, x_ate, r_ate, dragon_ate, tau.mean()],
                         index=['S', 'T', 'X', 'R', 'dragonnet', 'actual'], columns=['ATE'])
df_result['MAE'] = [mean_absolute_error(t, p) for t, p
                    in zip([s_ite, t_ite, x_ite, r_ite, dragon_ite],
                           [tau.values.reshape(-1, 1)] * 5)] + [None]
df_result['AUUC'] = auuc_score(df_preds)
df_result
| | ATE | MAE | AUUC |
|---|---|---|---|
| S | 4.054511 | 1.027666 | 0.575822 |
| T | 4.100199 | 0.980788 | 0.580929 |
| X | 4.021592 | 1.113436 | 0.564887 |
| R | 3.537158 | 1.901297 | 0.553742 |
| dragonnet | 4.011624 | 1.162131 | 0.556887 |
| actual | 4.098887 | NaN | NaN |
plot_gain(df_preds)
causalml Synthetic Data Generation Method
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)
X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
    train_test_split(X, y, w, tau, b, e, test_size=0.2, random_state=123, shuffle=True)
preds_dict_train = {}
preds_dict_valid = {}
preds_dict_train['Actuals'] = tau_train
preds_dict_valid['Actuals'] = tau_val
preds_dict_train['generated_data'] = {
    'y': y_train,
    'X': X_train,
    'w': w_train,
    'tau': tau_train,
    'b': b_train,
    'e': e_train}
preds_dict_valid['generated_data'] = {
    'y': y_val,
    'X': X_val,
    'w': w_val,
    'tau': tau_val,
    'b': b_val,
    'e': e_val}
# Predict p_hat because e would not be directly observed in real-life
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.fit_predict(X_val, w_val)
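Note that p_hat_val above comes from a propensity model refit on the validation split itself. A hedged alternative (a sketch, not the notebook's approach) is to fit once on the training split with the already-imported scikit-learn classes and score both splits with that single model:
# Sketch: fit the propensity model on the training split only and reuse it for validation.
alt_p_model = LogisticRegressionCV(cv=StratifiedKFold(n_splits=4, shuffle=True, random_state=123))
alt_p_model.fit(X_train, w_train)
alt_p_hat_train = alt_p_model.predict_proba(X_train)[:, 1]
alt_p_hat_val = alt_p_model.predict_proba(X_val)[:, 1]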
for base_learner, label_l in zip([BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
                                 ['S', 'T', 'X', 'R']):
    for model, label_m in zip([LinearRegression, XGBRegressor], ['LR', 'XGB']):
        # RLearner will need to fit on the p_hat
        if label_l != 'R':
            learner = base_learner(model())
            # fit the model on training data only
            learner.fit(X=X_train, treatment=w_train, y=y_train)
            try:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, p=p_hat_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, p=p_hat_val).flatten()
            except TypeError:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
        else:
            learner = base_learner(model())
            learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
            preds_dict_train['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_train).flatten()
            preds_dict_valid['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_val).flatten()
learner = DragonNet(verbose=False)
learner.fit(X_train, treatment=w_train, y=y_train)
preds_dict_train['DragonNet'] = learner.predict_tau(X=X_train).flatten()
preds_dict_valid['DragonNet'] = learner.predict_tau(X=X_val).flatten()
actuals_train = preds_dict_train['Actuals']
actuals_validation = preds_dict_valid['Actuals']
synthetic_summary_train = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_train)] for label, preds
                                        in preds_dict_train.items() if 'generated' not in label.lower()},
                                       index=['ATE', 'MSE']).T
synthetic_summary_train['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_train['ATE'] / synthetic_summary_train.loc['Actuals', 'ATE']) - 1)
synthetic_summary_validation = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_validation)]
                                             for label, preds in preds_dict_valid.items()
                                             if 'generated' not in label.lower()},
                                            index=['ATE', 'MSE']).T
synthetic_summary_validation['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_validation['ATE'] / synthetic_summary_validation.loc['Actuals', 'ATE']) - 1)
# calculate kl divergence for training
for label in synthetic_summary_train.index:
    stacked_values = np.hstack((preds_dict_train[label], actuals_train))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)
    distr = np.histogram(preds_dict_train[label], bins=bins)[0]
    distr = np.clip(distr / distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_train, bins=bins)[0]
    true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
    kl = entropy(distr, true_distr)
    synthetic_summary_train.loc[label, 'KL Divergence'] = kl
# calculate kl divergence for validation
for label in synthetic_summary_validation.index:
    stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)
    distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
    distr = np.clip(distr / distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_validation, bins=bins)[0]
    true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
    kl = entropy(distr, true_distr)
    synthetic_summary_validation.loc[label, 'KL Divergence'] = kl
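The two loops above differ only in their inputs; as a sketch (not part of causalml), the binned KL computation can be factored into a reusable helper:
def binned_kl_divergence(preds, actuals, n_edges=100):
    # Sketch of a reusable version of the loops above: histogram predictions and
    # actuals on a shared grid (bounded by the 0.1/99.9 percentiles) and compare
    # the normalized, clipped histograms with scipy's entropy.
    stacked_values = np.hstack((preds, actuals))
    bins = np.linspace(np.percentile(stacked_values, 0.1),
                       np.percentile(stacked_values, 99.9), n_edges)
    distr = np.histogram(preds, bins=bins)[0]
    distr = np.clip(distr / distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals, bins=bins)[0]
    true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
    return entropy(distr, true_distr)

# e.g. binned_kl_divergence(preds_dict_train['S Learner (LR)'], actuals_train)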
df_preds_train = pd.DataFrame([preds_dict_train['S Learner (LR)'].ravel(),
                               preds_dict_train['S Learner (XGB)'].ravel(),
                               preds_dict_train['T Learner (LR)'].ravel(),
                               preds_dict_train['T Learner (XGB)'].ravel(),
                               preds_dict_train['X Learner (LR)'].ravel(),
                               preds_dict_train['X Learner (XGB)'].ravel(),
                               preds_dict_train['R Learner (LR)'].ravel(),
                               preds_dict_train['R Learner (XGB)'].ravel(),
                               preds_dict_train['DragonNet'].ravel(),
                               preds_dict_train['generated_data']['tau'].ravel(),
                               preds_dict_train['generated_data']['w'].ravel(),
                               preds_dict_train['generated_data']['y'].ravel()],
                              index=['S Learner (LR)', 'S Learner (XGB)',
                                     'T Learner (LR)', 'T Learner (XGB)',
                                     'X Learner (LR)', 'X Learner (XGB)',
                                     'R Learner (LR)', 'R Learner (XGB)',
                                     'DragonNet', 'tau', 'w', 'y']).T
synthetic_summary_train['AUUC'] = auuc_score(df_preds_train).iloc[:-1]
df_preds_validation = pd.DataFrame([preds_dict_valid['S Learner (LR)'].ravel(),
                                    preds_dict_valid['S Learner (XGB)'].ravel(),
                                    preds_dict_valid['T Learner (LR)'].ravel(),
                                    preds_dict_valid['T Learner (XGB)'].ravel(),
                                    preds_dict_valid['X Learner (LR)'].ravel(),
                                    preds_dict_valid['X Learner (XGB)'].ravel(),
                                    preds_dict_valid['R Learner (LR)'].ravel(),
                                    preds_dict_valid['R Learner (XGB)'].ravel(),
                                    preds_dict_valid['DragonNet'].ravel(),
                                    preds_dict_valid['generated_data']['tau'].ravel(),
                                    preds_dict_valid['generated_data']['w'].ravel(),
                                    preds_dict_valid['generated_data']['y'].ravel()],
                                   index=['S Learner (LR)', 'S Learner (XGB)',
                                          'T Learner (LR)', 'T Learner (XGB)',
                                          'X Learner (LR)', 'X Learner (XGB)',
                                          'R Learner (LR)', 'R Learner (XGB)',
                                          'DragonNet', 'tau', 'w', 'y']).T
synthetic_summary_validation['AUUC'] = auuc_score(df_preds_validation).iloc[:-1]
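The two prediction frames above are assembled with identical boilerplate; a small helper (a sketch, not part of causalml) could build either split from its predictions dictionary:
def build_preds_frame(preds_dict, model_labels):
    # Sketch: assemble the uplift-evaluation frame for one split, mirroring the
    # explicit DataFrame constructions above.
    data = {label: preds_dict[label].ravel() for label in model_labels}
    generated = preds_dict['generated_data']
    data['tau'] = generated['tau'].ravel()
    data['w'] = generated['w'].ravel()
    data['y'] = generated['y'].ravel()
    return pd.DataFrame(data)

# e.g. auuc_score(build_preds_frame(preds_dict_valid,
#                                   ['S Learner (LR)', 'S Learner (XGB)',
#                                    'T Learner (LR)', 'T Learner (XGB)',
#                                    'X Learner (LR)', 'X Learner (XGB)',
#                                    'R Learner (LR)', 'R Learner (XGB)', 'DragonNet']))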
synthetic_summary_train
| | ATE | MSE | Abs % Error of ATE | KL Divergence | AUUC |
|---|---|---|---|---|---|
| Actuals | 0.484486 | 0.000000 | 0.000000 | 0.000000 | NaN |
| S Learner (LR) | 0.528743 | 0.044194 | 0.091349 | 3.473087 | 0.492660 |
| S Learner (XGB) | 0.315706 | 0.060831 | 0.348369 | 0.423556 | 0.575274 |
| T Learner (LR) | 0.493815 | 0.022688 | 0.019255 | 0.289978 | 0.610855 |
| T Learner (XGB) | 0.443124 | 0.359123 | 0.085374 | 0.785408 | 0.544435 |
| X Learner (LR) | 0.493815 | 0.022688 | 0.019255 | 0.289978 | 0.610855 |
| X Learner (XGB) | 0.364116 | 0.205326 | 0.248448 | 0.530261 | 0.554322 |
| R Learner (LR) | 0.473947 | 0.026080 | 0.021752 | 0.352075 | 0.613613 |
| R Learner (XGB) | 0.376447 | 0.499069 | 0.222997 | 0.847715 | 0.527988 |
| DragonNet | 0.408403 | 0.042945 | 0.157038 | 0.434239 | 0.612939 |
synthetic_summary_validation
| | ATE | MSE | Abs % Error of ATE | KL Divergence | AUUC |
|---|---|---|---|---|---|
| Actuals | 0.511242 | 0.000000 | 0.000000 | 0.000000 | NaN |
| S Learner (LR) | 0.528743 | 0.042236 | 0.034233 | 4.574498 | 0.494022 |
| S Learner (XGB) | 0.341574 | 0.066174 | 0.331874 | 0.779572 | 0.567832 |
| T Learner (LR) | 0.541503 | 0.025840 | 0.059191 | 0.686602 | 0.604712 |
| T Learner (XGB) | 0.467758 | 0.303262 | 0.085055 | 0.942250 | 0.550549 |
| X Learner (LR) | 0.541503 | 0.025840 | 0.059191 | 0.686602 | 0.604712 |
| X Learner (XGB) | 0.364071 | 0.164907 | 0.287869 | 0.648777 | 0.555098 |
| R Learner (LR) | 0.526938 | 0.029887 | 0.030702 | 0.739050 | 0.607303 |
| R Learner (XGB) | 0.428291 | 0.324614 | 0.162253 | 0.732380 | 0.536405 |
| DragonNet | 0.460422 | 0.041291 | 0.099405 | 0.843017 | 0.606557 |
plot_gain(df_preds_train)
plot_gain(df_preds_validation)