import ml_1_13_functions as ml

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import ShuffleSplit

# Dataset-import request: tells ml.import_dataset which columns of the Excel
# table become features / target, and how the raw data is cleaned up.
data_rqst = {
    "dataset_name" : '1-13_dataset.xlsx', # '1-13_dataset_rev20230616.xlsx'
    
    "target" : 'S_max',        # 'T_tr', 'T_hys', 'S_max', 'S_FWHM', 'T_ad', or 'benchmark'

    # Optional feature groups to append to the feature matrix:
    "method"    : False,      # include methods of synthesis into feature matrix
    "anneling"  : True,       # include annealing conditions
                              # NOTE(review): key spelling 'anneling' is kept as-is —
                              # presumably the consumer in ml_1_13_functions expects it; verify before renaming
    "phases"    : False,      # include volume fractions of phases
    "lat_const" : False,      # include lattice constants

    # Cleanup policy:
    "feature_gaps" : True,    # delete samples which have missed features
    "feature_empty" : True,   # delete features with no data

    "restore_T_tr"  : True,   # restore T_tr from the peak locations of entropy or T_ad data 
    "restore_H_lim" : 2,      # [T] max magnetic field at which T_tr can be restored
    
    "duplicates" : 'avr',     # ['avr'/'min'/'max'] policy for processing duplicates
    "rounding" : 2            # number of decimals to round features (except lat_const) and target
}

# Settings for the optional outlier screening pass (used only when
# 'check' is True): train N seeded GBR models on random splits and flag
# samples with extreme per-sample MAE.
outliers = {
    'check': False,           # check for outliers based on extreme MAE of individual samples
    'N': 50,                  # number of training sessions to collect MAE statistics
    'test_size': 0.2,         # fraction of samples held out in each session
    'seed': 17,               # random state
    'GBR_param': {'n_estimators': 200, 'max_depth': 4, 'min_samples_split': 4, 'learning_rate': 0.1}
}


#------------------------------------------------------------------------------

# Load the dataset per data_rqst: sample IDs, feature matrix X, target
# vector y, plus human-readable names and units for features/target.
x_ID, X, y, Xy_names, Xy_units = ml.import_dataset(data_rqst)


# Optional outlier screening (see the `outliers` dict): flags samples with
# extreme individual MAE across repeated GBR training sessions.
if outliers['check']:
    gbr = GradientBoostingRegressor(**outliers['GBR_param'])
    ml.check_outliers(X, y, x_ID, Xy_units, Xy_names, gbr, outliers)


#--- Gradient Boosting Regression ---------------------------------------------

# Gradient-boosting regression settings.
#  - 'opt'        : run a grid search over 'param_grid' before training
#  - 'param_opt'  : final hyper-parameters; filled below per target
#  - 'scaling'    : whether features are scaled (tree ensembles do not need it)
#  - 'cv_splits' / 'test_size' / 'seed' : ShuffleSplit cross-validation setup
regr_gbr = {
    'opt': False,
    'param_opt': {},
    'param_grid': [
        {
            'n_estimators':  [50, 150, 200, 250, 300, 350, 400],
            'max_depth': [3, 4, 5, 6],
            'min_samples_split': [2, 3, 4],
            'learning_rate': [0.05, 0.1, 0.15, 0.2, 0.25]
         }
    ],
    'scaling': False,
    'test_size': 0.2,
    'cv_splits': 10,
    'seed': 15
}

if regr_gbr['opt']:
    # Grid-search the hyper-parameters in regr_gbr['param_grid'].
    gbr = GradientBoostingRegressor(random_state = regr_gbr['seed'])
    ml.optimize(X, y, x_ID, Xy_units, Xy_names, gbr, regr_gbr)

# Pre-optimized hyper-parameters per target (results of earlier grid
# searches). Targets not listed here (e.g. 'benchmark') keep an empty
# param_opt and fall back to the scikit-learn defaults.
_PARAM_BY_TARGET = {
    'T_tr':   {'n_estimators': 200, 'max_depth': 4, 'min_samples_split': 4, 'learning_rate': 0.1},
    'T_hys':  {'n_estimators': 150, 'max_depth': 4, 'min_samples_split': 3, 'learning_rate': 0.1},
    'S_max':  {'n_estimators': 400, 'max_depth': 5, 'min_samples_split': 4, 'learning_rate': 0.15},
    'S_FWHM': {'n_estimators': 300, 'max_depth': 5, 'min_samples_split': 3, 'learning_rate': 0.1},
    'T_ad':   {'n_estimators': 150, 'max_depth': 4, 'min_samples_split': 3, 'learning_rate': 0.1},
}
regr_gbr['param_opt'] = _PARAM_BY_TARGET.get(data_rqst["target"],
                                             regr_gbr['param_opt'])

# Fix random_state so training is reproducible — consistent with the seeded
# optimizer model above and the seeded ShuffleSplit used for CV below.
gbr = GradientBoostingRegressor(random_state = regr_gbr['seed'],
                                **regr_gbr['param_opt'])

ml.train_model(X, y, Xy_units, Xy_names, gbr, regr_gbr)

# Cross-validation: seeded random train/test splits.
sp = ShuffleSplit(
    n_splits = regr_gbr['cv_splits'],
    test_size = regr_gbr['test_size'],
    train_size = None,
    random_state = regr_gbr['seed'])

# Per-split test metrics. NOTE(review): the loop below stops after 5 splits,
# so entries 5..cv_splits-1 remain zero.
RMSE = np.zeros((regr_gbr['cv_splits']), dtype = float)    
MAE  = np.zeros((regr_gbr['cv_splits']), dtype = float)
R2   = np.zeros((regr_gbr['cv_splits']), dtype = float)

print(f'\nTraining (rs = {regr_gbr["seed"]}):')
for i, (train, test) in enumerate(sp.split(X, y)):

    X_train, X_test = X[train], X[test]
    y_train, y_test = y[train], y[test]

    gbr.fit(X_train, y_train)

    y_pred = gbr.predict(X_test)

    # np.sqrt instead of mean_squared_error(..., squared=False): the
    # `squared` keyword was deprecated in scikit-learn 1.4 and removed in
    # 1.6, where root_mean_squared_error replaces it. The sqrt form is
    # numerically identical and works on every sklearn version.
    RMSE[i] = np.sqrt(mean_squared_error(y_test, y_pred))
    MAE[i] = mean_absolute_error(y_test, y_pred)
    R2[i]  = r2_score(y_test, y_pred)
    
    print(' - %d: MAE = %.2f; RMSE = %.2f; R2 = %.2f' % (i, MAE[i], RMSE[i], R2[i]))
    
    # NOTE(review): only the first 5 of cv_splits splits are evaluated —
    # presumably a leftover shortcut; remove the break to run all splits.
    if i == 4: break