import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import ShuffleSplit, GridSearchCV, LearningCurveDisplay
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

# -----------------------------------------------------------------------------

def import_dataset(data_rqst):
    """Load the dataset from Excel and build the feature matrix and target.

    data_rqst is a dict with (at least) the keys:
      'target'        -- 'T_tr', 'T_hys', 'S_max', 'S_FWHM', 'T_ad' or 'benchmark'
      'dataset_name'  -- path to the Excel file
      'anneling', 'phases', 'lat_const', 'method' -- flags adding column groups
      'restore_T_tr'  -- (T_tr only) restore missing T_tr from peak data
      'feature_gaps'  -- drop samples that have missing feature values
      'rounding'      -- number of decimals for features and target
      'duplicates'    -- policy for duplicated feature vectors ('avr'/'min'/'max')
      'feature_empty' -- drop features that are zero for every sample

    Returns (x_ID, X, y, Xy_names, Xy_units) on success, or ('', '') when the
    requested target is unknown (NOTE(review): the error value has a different
    arity than the success value — callers must check for it explicitly).
    """

    if data_rqst['target'] == 'T_tr':
        usecols = 'A, F:AC, AE:AH, AZ' # H_tr data is ignored
        samples_to_delete = [48, 74, 429, 662, 666, 667, 668, 850, 913, 915, 1031, 1043]
        samples_to_delete += [555, 669, 670, 671]       # nitrided samples (two papers)
        samples_to_delete += [992, 993, 994]            # superposition of Co and Mn (10.12693/APhysPolA.133.648)
        samples_to_delete += [72, 73, 74, 75, 76, 295]  # Al-containing samples (high amount, except 295)
        name = r'$T_{tr}$'

        samples_to_delete += [492, 493]                # Dy
        samples_to_delete += [495, 496]                # Er
        samples_to_delete += [761]                     # Ho
        samples_to_delete += [1071, 1072, 1073, 1074]  # Tb
        samples_to_delete += [297]                     # V
        samples_to_delete += [450, 451, 452]           # Ca
        samples_to_delete += [1034]                    # P
        samples_to_delete += [1033]                    # Zn
        samples_to_delete += [1096]                    # Nb

    elif data_rqst['target'] == 'T_hys':
        usecols = 'A, F:AC, AE:AH, BA' # H_tr data is ignored
        samples_to_delete = []
        name = r'$T_{hys}$'

    elif data_rqst['target'] == 'benchmark':
        usecols = 'A, F:AC, AE:AH, AZ, BB:BC'
        samples_to_delete = []
        name = r'$|ΔS_M|_{max}$'

    elif data_rqst['target'] == 'S_max':
        usecols = 'A, F:AC, AE:AH, BB:BC'
        samples_to_delete = ['48_2.0', '48_5.0', 387, 428, 429, 576, '628_5.0', 804, 903, '1031_2.0', '1031_5.0', '1043_1.0', '1043_2.0']
        samples_to_delete += [1032, 1033, 1034]                         # 10.3390/met9040432
        samples_to_delete += [810, 811, 812]                            # |ΔS_M| peak has just few points (10.1007/s12598-011-0395-1)
        samples_to_delete += ['555_1.0','555_2.0','555_3.0','555_5.0']  # nitrided sample (10.1016/j.jmmm.2008.08.081)
        name = r'$|ΔS_M|_{max}$'

        samples_to_delete += ['492_1.0', '492_2.0', '492_4.0', '492_5.0', '493_1.0', '493_2.0', '493_4.0', '493_5.0']   # Dy
        samples_to_delete += ['495_1.0', '495_2.0', '495_3.0', '495_5.0', 496]   # Er
        samples_to_delete += ['761_1.0', '761_2.0', '761_4.0', '761_5.0']   # Ho
        samples_to_delete += [1071, 1072, 1073, 1074]  # Tb
        samples_to_delete += ['297_1.0', '297_2.0', '297_3.0', '297_4.0', '297_5.0']   # V
        samples_to_delete += ['450_2.0', '450_5.0', '451_2.0', '451_5.0', '452_2.0', '452_5.0']   # Ca
        samples_to_delete += [1034]   # P
        samples_to_delete += [1033]   # Zn
        samples_to_delete += [1096]   # Nb

    elif data_rqst['target'] == 'S_FWHM':
        usecols = 'A, F:AC, AE:AH, BB, BE'
        samples_to_delete = ['51_5.0', 61, 648]
        samples_to_delete += ['48_2.0', '48_5.0', 387, 428, 429, 576, '1031_2.0', '1031_5.0', '1043_1.0', '1043_2.0']
        samples_to_delete += [1032, 1033, 1034]                         # 10.3390/met9040432
        samples_to_delete += ['555_1.0','555_2.0','555_3.0','555_5.0']  # nitrided sample (10.1016/j.jmmm.2008.08.081)
        samples_to_delete += ['295_0.6','295_1.0','295_2.0']            # Al-containing sample (10.1109/20.951162)
        samples_to_delete += [660, 661, 662]                            # 10.1063/1.1448793
        samples_to_delete += [389, 390]                                 # 10.1063/1.2203389
        samples_to_delete += [1053, 1054]                               # 10.1016/j.jmmm.2006.09.010
        name = r'$|ΔS_M|_{FWHM}$'

        samples_to_delete += ['492_1.0', '492_2.0', '492_4.0', '492_5.0', '493_1.0', '493_2.0', '493_4.0', '493_5.0']   # Dy
        samples_to_delete += ['495_1.0', '495_2.0', '495_3.0', '495_5.0', 496]   # Er
        samples_to_delete += ['761_1.0', '761_2.0', '761_4.0', '761_5.0']   # Ho
        samples_to_delete += [1071, 1072, 1073, 1074]  # Tb
        samples_to_delete += ['297_1.0', '297_2.0', '297_3.0', '297_4.0', '297_5.0']   # V
        samples_to_delete += ['450_2.0', '450_5.0', '451_2.0', '451_5.0', '452_2.0', '452_5.0']   # Ca
        samples_to_delete += [1034]   # P
        samples_to_delete += [1033]   # Zn
        samples_to_delete += [1096]   # Nb

    elif data_rqst['target'] == 'T_ad':
        usecols = 'A, F:AC, AE:AH, BF:BG'
        samples_to_delete = []
        name = r'$T_{ad}$'

    else:
        print('\n\033[31m\033[1mERROR\033[0m: requested target is not recognised.')
        print('Available options: "T_tr", "T_hys", "S_max", "S_FWHM", "T_ad", or "benchmark"')
        return '', ''

    # optional column groups
    if data_rqst['anneling']:  usecols += ', AJ:AK'
    if data_rqst['phases']:    usecols += ', AL:AT'
    if data_rqst['lat_const']: usecols += ', AW:AX'
    if data_rqst['method']:    usecols += ', AI'

    # extra peak-temperature columns needed to restore missing T_tr values
    if data_rqst['target'] == 'T_tr' and data_rqst['restore_T_tr']:
        usecols += ', BB, BD, BF, BH'

    df = pd.read_excel(data_rqst['dataset_name'], index_col = 0, usecols = usecols)

    if data_rqst['method']:
        # one-hot encode 'Method', keeping the dummy columns at its original
        # position (NOTE(review): assumes exactly 4 method values — confirm)
        old_names = list(df.columns)
        i = old_names.index('Method')
        df = pd.get_dummies(df, columns = ['Method'])
        new_names = old_names[:i] + list(df.columns)[-4:] + old_names[i+1:]
        df = df[new_names]

    df, Xy_units = process_names_units(df)

    if data_rqst['target'] == 'T_tr' and data_rqst['restore_T_tr']:
        df, Xy_units = restore_T_tr(df, data_rqst, Xy_units)

    # expand bracketed field/target lists into one row per field value
    if data_rqst['target'] in ('S_max', 'benchmark'):
        df = process_lists(df,'HS','|ΔSM|max')

    if data_rqst['target'] == 'T_ad':
        df = process_lists(df,'Had','|ΔTad|max')

    if data_rqst['target'] == 'S_FWHM':
        df = process_lists(df,'HS','ΔSM_FWHM')

    # deleting samples with missed targets
    for index, row in df.iterrows():
        if np.isnan(row.iloc[-1]): samples_to_delete.append(index)
    df = df.drop(samples_to_delete, axis = 'index')

    # deleting samples with missed features (optional)
    samples_to_delete = []
    if data_rqst['feature_gaps']:
        for index, row in df.iterrows():
            if row[:-1].isna().any():
                samples_to_delete.append(index)
    df = df.drop(samples_to_delete, axis = 'index')

    y = df[df.columns[-1]].values
    X = df.drop([df.columns[-1]], axis = 'columns').values
    Xy_names = list(df.columns)
    Xy_names[-1] = name  # pretty LaTeX label for the target
    x_ID = np.array(df.index)

    X = X.astype('float64')
    y = y.astype('float64')

    # rounding features; the a@PM / c@PM lattice constants keep full precision
    for i in range(len(Xy_names)-1):
        if Xy_names[i] != 'a@PM' and Xy_names[i] != 'c@PM':
            X[:,i] = np.round(X[:,i], data_rqst['rounding'])

    # rounding target
    y = np.round(y, data_rqst['rounding'])

    x_ID, X, y = process_duplicates(x_ID, X, y, data_rqst['duplicates'])

    if data_rqst['feature_empty']:
        # drop features that are zero for every remaining sample
        empty_found = False
        print('\nChecking for empty features...')
        for i in reversed(range(len(Xy_names)-1)):
            if np.sum(X[:,i]) == 0:
                empty_found = True
                X = np.delete(X, (i), axis = 1)
                print(f' - feature \033[31m\033[1m"{Xy_names[i]}"\033[0m is removed')
                Xy_names.pop(i)
                Xy_units.pop(i)
        if not empty_found:
            print('There are no empty features')

    N_sm = len(y)
    N_ft = len(Xy_names)-1

    print('\nDataset from \033[31m\033[1m' + data_rqst['dataset_name'] + '\033[0m is uploaded:')
    print(f' - target is {data_rqst["target"]}')
    print(f' - {N_sm} samples')
    print(f' - {N_ft} features')

    return x_ID, X, y, Xy_names, Xy_units


# -----------------------------------------------------------------------------

def restore_T_tr(df, data_rqst, units, plot_comparison = False):
    """Fill missing transition temperatures ('Ttr') from peak-temperature data.

    For every sample, the bracketed list strings in the columns 'HS',
    'T@|ΔSM|max', 'Had' and 'T@|ΔTad|max' (magnetic fields and the matching
    peak temperatures of |ΔS_M| and |ΔT_ad|) are parsed.  When a sample's
    'Ttr' is missing, it is replaced by the peak temperature measured at the
    lowest field, provided that field does not exceed
    data_rqst['restore_H_lim']; if both peaks qualify, the one at the lower
    field wins, and at equal fields their average is used.  Samples that
    already have 'Ttr' are used to cross-check the estimate (a warning is
    printed when the difference exceeds 20 K) and are optionally plotted
    when plot_comparison is True.

    NOTE(review): the loop assumes the dataframe index is the contiguous
    range 1..N and that each list's first element corresponds to the lowest
    field — confirm against the Excel sheet.

    Returns (df, units) with the four auxiliary peak columns (the last 4)
    and their units removed.
    """

    T_tr = []    # known transition temperatures (cross-check samples)
    T_peak = []  # restored estimates for those same samples

    print('\nRestoring missed transition temperatures:\n')

    j = 0  # number of samples whose Ttr was actually restored
    for i in range(1,df.shape[0]+1):

        # taking the temperature of ΔS peak at the lowest H
        HS_list = []
        if isinstance(df['HS'][i], str):
            if df['HS'][i][0] != '[' or df['HS'][i][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets for H_S of the sample {i}\n')
            HS_list = [float(x) for x in df['HS'][i][1:-1].split(',')]

        TS_list = []
        if isinstance(df['T@|ΔSM|max'][i], str):
            if df['T@|ΔSM|max'][i][0] != '[' or df['T@|ΔSM|max'][i][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets for T@|ΔSM|max of the sample {i}\n')
            TS_list = [float(x) for x in df['T@|ΔSM|max'][i][1:-1].split(',')]

        # candidate from the ΔS peak: only if a matching pair of lists exists
        # and the lowest field is within the allowed limit
        TS_rest = float('NaN')
        if len(HS_list) + len(TS_list) != 0:
            if len(HS_list) == len(TS_list):
                if HS_list[0] <= data_rqst['restore_H_lim']:
                    TS_rest = TS_list[0]

        # taking the temperature of ΔTad peak at the lowest H
        HTad_list = []
        if isinstance(df['Had'][i], str):
            if df['Had'][i][0] != '[' or df['Had'][i][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets for H_ad of the sample {i}\n')
            HTad_list = [float(x) for x in df['Had'][i][1:-1].split(',')]

        TTad_list = []
        if isinstance(df['T@|ΔTad|max'][i], str):
            if df['T@|ΔTad|max'][i][0] != '[' or df['T@|ΔTad|max'][i][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets for T@|ΔTad|max of the sample {i}\n')
            TTad_list = [float(x) for x in df['T@|ΔTad|max'][i][1:-1].split(',')]

        # candidate from the ΔTad peak, same conditions as above
        TTad_rest = float('NaN')
        if len(HTad_list) + len(TTad_list) != 0:
            if len(HTad_list) == len(TTad_list):
                if HTad_list[0] <= data_rqst['restore_H_lim']:
                    TTad_rest = TTad_list[0]

        # choose between the two candidates: prefer the one measured at the
        # lower field; average them when the fields coincide
        T_rest = float('NaN')
        if np.isnan(TS_rest):
            if not np.isnan(TTad_rest):
                T_rest = TTad_rest
                log_out = f' {i}: T_tr ≈ {T_rest} K from T_ad peak at {HTad_list[0]} T'
        else:
            if np.isnan(TTad_rest):
                T_rest = TS_rest
                log_out = f' {i}: T_tr ≈ {T_rest} K from S peak at {HS_list[0]} T'

            else:
                if HS_list[0] < HTad_list[0]:
                    T_rest = TS_rest
                    log_out = f' {i}: T_tr ≈ {T_rest} K from S peak at {HS_list[0]} T'

                elif HTad_list[0] < HS_list[0]:
                    T_rest = TTad_rest
                    log_out = f' {i}: T_tr ≈ {T_rest} K from T_ad peak at {HTad_list[0]} T'

                else:
                    T_dif = round(abs(TS_rest - TTad_rest), 1)
                    T_rest = round((TS_rest + TTad_rest)/2, 1)
                    log_out = f' {i}: T_tr ≈ {T_rest} K as the average Temp. of S and T_ad peaks at {HS_list[0]} T (T diff. {T_dif} K)'

        # write the estimate only where Ttr is missing; otherwise use the
        # estimate as a consistency check against the tabulated value
        if np.isnan(df['Ttr'][i]):
            if not np.isnan(T_rest):
                df.at[i, 'Ttr'] = T_rest
                print(log_out)
                j += 1
        else:
            if not np.isnan(T_rest):
                T_peak.append(T_rest)
                T_tr.append(df['Ttr'][i])
                T_dif = round(abs(T_rest - df['Ttr'][i]), 1)
                if T_dif > 20:
                    print(f' \033[31m\033[1mWARNING\033[0m: large T_dif of {T_dif} K for sample {i}')

    print(f'\nTransition temperatures are restored for {j} samples')

    # scatter of restored estimates vs. tabulated values with identity line
    if plot_comparison:
        plt.figure(figsize=(5,5), dpi = 500)
        plt.scatter(T_peak, T_tr, s = 10, color = 'blue', marker = 'o')
        plt.xlabel('Peak temperature (K)', size = 18)
        plt.ylabel('Transition temperature (K)', size = 18)
        plt.title(f'Magnetic field <= {data_rqst["restore_H_lim"]} T')

        max_y = np.max(T_peak + T_tr)*1.05
        min_y = np.min(T_peak + T_tr)*0.95

        plt.plot([min_y, max_y], [min_y, max_y], color = 'black', linewidth = 1)

        plt.gca().set_aspect('equal')
        plt.xlim(min_y, max_y)
        plt.ylim(min_y, max_y)

    # drop the four auxiliary peak columns and their units before returning
    df = df.drop(df.columns[-4:], axis = 'columns')
    units = units[:-4]

    return df, units


# -----------------------------------------------------------------------------

def process_lists(df, field, target):
    """Expand rows whose `field`/`target` cells hold bracketed list strings.

    A cell pair like field='[1.0, 2.0]', target='[3.1, 4.2]' is split into
    one new row per (field, target) value; each new row is indexed
    '<old index>_<field value>' and carries the scalar values in the two
    columns.  Rows whose cells are not both strings are copied unchanged.
    On a parsing problem (missing brackets, non-numeric entries) an error
    message is printed and the ORIGINAL dataframe is returned untouched;
    on a length mismatch only that sample is skipped (with an error print).
    """

    df_new = pd.DataFrame(columns = df.columns)

    for index, row in df.iterrows():
        if isinstance(row[field], str) and isinstance(row[target], str):
            if row[field][0] != '[' or row[field][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets or the last symbol in {field} of the sample {index}\n')
                return df

            HS_list = []
            for x in row[field][1:-1].split(','):
                try:
                    HS_list.append(float(x))
                except ValueError:
                    print(f'\n \033[31m\033[1mERROR\033[0m: conversion issue, check data format in {field} of the sample {index}\n')
                    return df

            if row[target][0] != '[' or row[target][-1] != ']':
                print(f'\n \033[31m\033[1mERROR\033[0m: check brackets or the last symbol in {target} of the sample {index}\n')
                return df

            S_list = []
            for x in row[target][1:-1].split(','):
                try:
                    S_list.append(float(x))
                except ValueError:
                    print(f'\n \033[31m\033[1mERROR\033[0m: conversion issue, check data format in {target} of the sample {index}\n')
                    return df

            # one new row per field value; the other columns are copied as-is
            # and the two list columns are overwritten with the scalars
            if len(HS_list) == len(S_list):
                for i in range(len(HS_list)):
                    index_aux = str(index) + '_' + str(HS_list[i])
                    df_new.loc[index_aux] = row
                    df_new.loc[index_aux, (field)] = HS_list[i]
                    df_new.loc[index_aux, (target)] = S_list[i]
            else:
                print(f'\n \033[31m\033[1mERROR\033[0m: length mismatch of HS and {target} data for sample {index}\n')

        else:
            df_new.loc[index] = row

    # restore the original column dtypes (new rows were built as object)
    df_new = df_new.astype(df.dtypes.to_dict())
    print('\nLists in features and targets are processed')

    return df_new


# -----------------------------------------------------------------------------

def process_names_units(df):
    """Split measurement units out of the dataframe's column names.

    A column named like 'Temp (K)' is renamed to 'Temp' and the text inside
    the parentheses ('K') is collected as its unit.  A column like
    'a [note]' has the bracketed part removed and gets the placeholder unit
    'nan'; any other column keeps its name and also gets 'nan'.  Remaining
    spaces inside names are replaced with underscores.

    Returns (df, units) where units[i] matches df.columns[i].
    Raises KeyError (via rename errors='raise') if a rename cannot be applied.
    """

    units = []
    for old_name in df.columns:
        # both delimiters must be present; checking only one of them would
        # crash on names that contain ')' or ']' alone
        if '(' in old_name and ')' in old_name:
            start = old_name.index('(')
            end = old_name.index(')')
            units.append(old_name[start+1:end])
            new_name = old_name[:start]
        else:
            if '[' in old_name and ']' in old_name:
                new_name = old_name[:old_name.index('[')]
            else:
                new_name = old_name
            units.append('nan')

        # drop trailing spaces left by the removed unit, then make the name
        # identifier-like (also safe when the name becomes empty)
        new_name = new_name.rstrip(' ').replace(' ', '_')

        df = df.rename(columns = {old_name: new_name}, errors = 'raise')

    return df, units


# -----------------------------------------------------------------------------

def process_duplicates(x_ID, X, y, policy):
    """Merge samples that share an identical feature vector.

    Every group of rows with equal features is collapsed into a single row
    (the group's first occurrence), whose target is the average ('avr'),
    minimum ('min') or maximum ('max') of the group's targets.  The merged
    rows are appended after the unique ones.

    Parameters:
      x_ID   -- 1-D array of sample IDs (ints or strings such as '48_2.0')
      X, y   -- feature matrix and target vector
      policy -- 'avr', 'min' or 'max'

    Returns the filtered (x_ID, X, y).
    Raises ValueError for an unknown policy (a silent skip would later
    produce arrays of inconsistent lengths).
    """

    N_sm = np.shape(X)[0]

    y = y.reshape(len(y), 1)
    x_ID = x_ID.reshape(len(x_ID), 1)

    # check[s] == 1 marks sample s as a member of some duplicate group
    check = np.zeros((N_sm, 1), dtype = int)

    dup_y = []   # combined target of each duplicate group
    dup_ID = []  # row position of each group's representative

    print('\nChecking for duplicating data:')

    for i in range(N_sm):
        if check[i] == 0:
            dup_y_aux = []
            dup_ID_aux = []
            for j in range(i+1, N_sm):
                if check[j] == 0 and np.array_equal(X[i], X[j]):
                    check[i] = 1
                    check[j] = 1
                    if not dup_y_aux:  # record the representative once
                        dup_y_aux.append(y[i,0])
                        dup_ID_aux.append(x_ID[i,0])
                    dup_y_aux.append(y[j,0])
                    dup_ID_aux.append(x_ID[j,0])
            if dup_y_aux:
                dup_ID.append(i)
                if policy == 'avr':
                    dup_y.append(np.average(dup_y_aux))
                elif policy == 'min':
                    dup_y.append(np.min(dup_y_aux))
                elif policy == 'max':
                    dup_y.append(np.max(dup_y_aux))
                else:
                    raise ValueError(f'unknown duplicates policy: {policy!r}')
                print(' - \033[31m\033[1m' + str(dup_ID_aux) + '\033[0m are duplicates with targets \033[31m\033[1m' + str(dup_y_aux) + '\033[0m')

    if not dup_ID:
        print('There are no duplicates in the data\n')
    else:
        # representative rows, taken before the duplicates are removed;
        # fancy indexing keeps x_ID's dtype, so string IDs survive
        dup_X = X[dup_ID, :]
        dup_x_ID = x_ID[dup_ID, :]

        # keep only samples that belong to no duplicate group
        keep = (check[:, 0] == 0)
        X = X[keep]
        y = y[keep]
        x_ID = x_ID[keep]

        dup_y = np.array(dup_y).reshape(len(dup_y), 1)

        # append one merged row per duplicate group
        X = np.concatenate((X, dup_X), axis = 0)
        y = np.concatenate((y, dup_y), axis = 0)
        x_ID = np.concatenate((x_ID, dup_x_ID), axis = 0)

        print(f'{np.sum(check) - len(dup_ID)} duplicates are deleted, corresponding targets are processed ({policy})')

    y = y.reshape(len(y))
    x_ID = x_ID.reshape(len(x_ID))

    return x_ID, X, y


# -----------------------------------------------------------------------------

def check_outliers(X, y, x_ID, Xy_units, Xy_names, model, outliers):
    """Screen for outlier samples via repeated random train/test splits.

    The model is refitted on outliers['N'] ShuffleSplit splits (test size
    outliers['test_size'], seed outliers['seed']); each time a sample falls
    into a test fold its absolute prediction error is recorded.  The
    per-sample mean error is plotted in descending order and the 15 worst
    samples are printed — persistently extreme samples are outlier
    candidates.
    """

    N_sm = len(x_ID)
    # out_MAE[s, k] = |error| of sample s in split k (NaN when s not tested)
    out_MAE = np.ones((N_sm,outliers['N']), dtype = float)*float('NaN')
    out_MAE_mean = np.zeros((N_sm), dtype = float)

    sp = ShuffleSplit(n_splits = outliers['N'], test_size = outliers['test_size'], train_size = None, random_state = outliers['seed'])

    print('\nCollecting data for outliers check:')
    for i, (train, test) in enumerate(sp.split(X, y)):

        X_train, X_test = X[train], X[test]
        y_train, y_test = y[train], y[test]

        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)
        MAE = mean_absolute_error(y_test, y_pred)

        # record the absolute error of every sample tested in this split
        for j in range(len(y_test)):
            out_MAE[test[j],i] = abs(y_test[j] - y_pred[j])

        print(' %2.0d: MAE = %.3f (%s)' % (i+1, MAE, Xy_units[-1]))

    # mean of the recorded errors; stays 0 for samples never tested
    for i in range(N_sm):
        mean = 0
        mean_N = 0
        for j in range(outliers['N']):
            if not np.isnan(out_MAE[i,j]):
                mean = mean + out_MAE[i,j]
                mean_N = mean_N + 1
        if mean_N != 0:
            mean = mean/mean_N
        out_MAE_mean[i] = mean

    # sample positions sorted by decreasing mean error
    index = np.argsort(out_MAE_mean, axis = 0)
    index = np.flipud(index)

    plt.figure(figsize=(5,5), dpi = 500)
    plt.title('Check for outliers based on extreme MAE')
    plt.xlabel('Sorted sample IDs', size = 14)
    plt.ylabel(f'{Xy_names[-1]} mean MAE ({Xy_units[-1]})', size = 14)
    plt.scatter(np.linspace(1,N_sm,N_sm), out_MAE_mean[index], s = 6, color = 'blue', marker = 'o')
    plt.show()

    N_show = 15
    print(f'\nThe worst {N_show} samples:')
    for i in range(N_show):
        print(f' - {x_ID[index][i]}: mean MAE = {out_MAE_mean[index][i]:.3f} ({Xy_units[-1]})')


# -----------------------------------------------------------------------------
    
def optimize(X, y, x_ID, Xy_units, Xy_names, model, regr):
    """Grid-search the model's hyperparameters with ShuffleSplit CV.

    regr keys used: 'cv_splits', 'test_size', 'seed', 'scaling',
    'param_grid'.  With regr['scaling'] set, the model is wrapped in a
    Pipeline behind a StandardScaler (param_grid keys must then carry the
    'regressor__' prefix).  Scoring is negated MAE.  After reporting the
    best score and parameters, the sorted MAE of every parameter set is
    plotted, and the user can interactively inspect the sets around a
    chosen set ID until 'n'/'N' is entered.

    x_ID is accepted for interface symmetry but not used here.
    """

    print('\nHyperparameters optimization, please wait...')

    sp = ShuffleSplit(n_splits = regr['cv_splits'], test_size = regr['test_size'], train_size = None, random_state = regr['seed'])

    if regr['scaling']:
        pipeline = Pipeline([('scaler', StandardScaler()), ('regressor', model)])
        search = GridSearchCV(estimator = pipeline, param_grid = regr['param_grid'], scoring = 'neg_mean_absolute_error', cv = sp, n_jobs = -1)
    else:
        search = GridSearchCV(estimator = model, param_grid = regr['param_grid'], scoring = 'neg_mean_absolute_error', cv = sp, n_jobs = -1)

    search.fit(X, y)

    print(f'\nOptimization has been finished (rs = {regr["seed"]}):')
    print(f' - best achieved MAE: \033[31m\033[1m{-search.best_score_:.2f} ({Xy_units[-1]})\033[0m')
    print(f' - best set of hyperparameters: \033[31m\033[1m{search.best_params_}\033[0m')

    results = search.cv_results_
    scores = results['mean_test_score']      # negated MAE (higher is better)
    index = np.argsort(scores)               # worst set first, best set last
    scores_sorted = -scores[index]           # back to positive MAE
    scores_std = results['std_test_score'][index]
    N_prm = len(index)

    plt.figure(figsize = (5,5), dpi = 500)
    plt.title('Optimization of hyperparameters ')
    plt.xlabel('Set of hyperparameters, ID', size = 14)
    plt.ylabel(f'{Xy_names[-1]} MAE ({Xy_units[-1]})', size = 14)
    x_featID = np.linspace(1,N_prm,N_prm)
    plt.fill_between(x_featID, scores_sorted - scores_std, scores_sorted + scores_std, color = 'red', alpha = 0.3, linewidth = 0)
    plt.scatter(x_featID, scores_sorted, s = 6, color = 'red', marker = 'o')
    plt.show()

    print('\nOptimization progress is plotted')

    print('\nCheck local sets of hyperparameters')

    N_sh = 2  # how many neighbouring sets to show on each side
    set_id = input('Enter set ID (n/N to quit): ')
    while set_id.lower() != 'n':
        # NOTE(review): j may leave the valid range for IDs near the ends
        # of the table (negative j wraps around) — confirm intended use
        for i in np.linspace(-N_sh, N_sh, 2*N_sh+1, dtype = int):
            j = int(set_id)+i
            print(f' - {j} set: {results["params"][index[j]]}, MAE = {-scores[index[j]]:.3f} ({Xy_units[-1]})')
        set_id = input('Enter set ID (n/N to quit): ')
     
        
# -----------------------------------------------------------------------------
     
def prediction_plot(y_train, y_fit, y_test, y_pred, i, Xy_names, Xy_units, rs, model_name):
    """Square scatter plot of predicted vs. experimental target values for
    one CV split, annotated with R², RMSE and MAE on the test set.

    y_train / y_fit -- experimental and fitted values of the training fold
    y_test / y_pred -- experimental and predicted values of the test fold
    i, rs           -- CV-split counter and random seed (shown in the title)
    """

    plt.figure(figsize=(5,5), dpi = 500)
    plt.scatter(y_train, y_fit, s = 10, color = 'blue', marker = 'o', label = 'Train data')
    plt.scatter(y_test, y_pred, s = 20, color = 'red', marker = 'o', label = 'Test data')
    plt.title(f'{model_name}: cv = {i}, rs = {rs}', size = 12)
    plt.xlabel(f'Experimental {Xy_names[-1]} ({Xy_units[-1]})', size = 14)
    plt.ylabel(f'Predicted {Xy_names[-1]} ({Xy_units[-1]})', size = 14)
    plt.legend(fontsize = 12)

    # common square limits with a 5% margin on both sides
    max_y = np.max([np.max(y_train), np.max(y_fit), np.max(y_test), np.max(y_pred)])
    max_y = max_y*1.05 if max_y > 0 else max_y*0.95

    min_y = np.min([np.min(y_train), np.min(y_fit), np.min(y_test), np.min(y_pred)])
    min_y = min_y*0.95 if min_y > 0 else min_y*1.05

    plt.plot([min_y, max_y], [min_y, max_y], color = 'black', linewidth = 1)  # identity line

    plt.gca().set_aspect('equal')
    plt.xlim(min_y, max_y)
    plt.ylim(min_y, max_y)

    # metric annotations placed in the lower-right corner
    txt_x  = min_y + 0.75*(max_y - min_y)
    txt_y1 = min_y + 0.1*(max_y - min_y)
    txt_y2 = min_y + 0.17*(max_y - min_y)
    txt_y3 = min_y + 0.24*(max_y - min_y)
    txt_kwargs = dict(ha = 'center', va = 'center', fontsize = 12)

    MAE = mean_absolute_error(y_test, y_pred)
    # sqrt(MSE): the `squared` keyword of mean_squared_error was deprecated
    # in scikit-learn 1.4 and removed in 1.6
    RMSE = np.sqrt(mean_squared_error(y_test, y_pred))
    R2  = r2_score(y_test, y_pred)

    plt.text(txt_x, txt_y1, r'$R^2$ = ' + f'{R2:.2f}', **txt_kwargs)
    plt.text(txt_x, txt_y2, f'RMSE = {RMSE:.2f} ({Xy_units[-1]})', **txt_kwargs)
    plt.text(txt_x, txt_y3, f'MAE = {MAE:.2f} ({Xy_units[-1]})', **txt_kwargs)

    plt.show()


# -----------------------------------------------------------------------------

def deviance(N_est, regr, X_test, y_test, y_pred):
    """Plot train vs. test deviance of a gradient-boosting model against
    the number of boosting iterations (log-scaled y axis).

    N_est  -- total number of boosting stages
    regr   -- fitted boosting model providing train_score_ and staged_predict
    y_pred -- unused; kept for interface compatibility
    """

    # test-set MSE after each boosting stage
    test_score = np.zeros((N_est,), dtype = float)
    for stage, staged in enumerate(regr.staged_predict(X_test)):
        test_score[stage] = mean_squared_error(y_test, staged)

    iterations = np.arange(N_est) + 1

    plt.figure(figsize = (5,5), dpi = 500)
    plt.plot(iterations, regr.train_score_, 'b-', label = 'Training Set')
    plt.plot(iterations, test_score, 'r-', label = 'Test Set')
    plt.legend(loc = 'upper right', fontsize = 14)
    plt.xlabel('Boosting Iterations', size = 14)
    plt.xlim(0, N_est*1.05)
    plt.ylabel('Deviance', size = 14)
    plt.yscale('log')
    

# -----------------------------------------------------------------------------
  
def feature_importance(Ft_imp, Xy_names, lim, i):
    """Plot a horizontal bar chart of the `lim` most important features.

    Ft_imp   -- importances in percent (scaled to sum to ~100 by the caller),
                one value per feature
    Xy_names -- feature names followed by the target name (target excluded)
    lim      -- number of top features to display
    i        -- CV-fold counter (kept for interface compatibility; unused)

    Also prints the cumulative importance of the displayed features,
    starting from the most important one.
    """
    # plt.cm.get_cmap was removed in matplotlib 3.9; use the registry instead
    from matplotlib import colormaps

    N_ft = Ft_imp.shape[0]
    first = N_ft - lim  # position of the first displayed feature (sorted order)

    sorted_idx = np.argsort(Ft_imp)
    pos = np.arange(N_ft) + 0.5

    names = np.array(Xy_names[:-1])

    Ft_sorted = Ft_imp[sorted_idx]

    # color encodes importance; /100 maps percentages into the colormap range
    colors = colormaps['autumn_r'](Ft_sorted[first:]/100)

    plt.figure(figsize=(3,6), dpi = 500)
    plt.barh(pos[first:], Ft_sorted[first:], align='center', color = colors)
    plt.yticks(pos[first:], names[sorted_idx[first:]], size = 16)
    plt.xlabel('Feature importance', size = 19)
    plt.xticks(fontsize = 13)

    # cumulative importance of the displayed features, best first
    Ft_cum = 0
    print('\nCumulative feature importance: ')
    for k in range(N_ft - first):
        Ft_cum += Ft_imp[sorted_idx[N_ft-k-1]]
        print(f'{k+1}: {Ft_cum:.2f}  ')
    
  
# -----------------------------------------------------------------------------
  
def learning_curve(X, y, Xy_units, Xy_names, model, regr, model_name):
    """Plot train/test MAE learning curves for `model` over ShuffleSplit CV.

    regr keys used: 'cv_splits', 'test_size', 'seed'.
    model_name is unused; kept for interface compatibility.
    """

    splitter = ShuffleSplit(n_splits = regr['cv_splits'], test_size = regr['test_size'], train_size = None, random_state = regr['seed'])

    fig, axis = plt.subplots(nrows = 1, ncols = 1, figsize = (5, 5), dpi = 500)

    LearningCurveDisplay.from_estimator(
        model,
        X = X,
        y = y,
        train_sizes = np.linspace(0.1, 1.0, 20),
        cv = splitter,
        score_type = "both",
        scoring = 'neg_mean_absolute_error',
        negate_score = True,  # report MAE instead of its negative
        score_name = f'MAE ({Xy_units[-1]})',
        n_jobs = -1,
        line_kw = {"marker": "o"},
        std_display_style = "fill_between",
        ax = axis)

    axis.set_title(f'Learning curve for {Xy_names[-1]}')
    axis.xaxis.label.set_fontsize(14)
    axis.yaxis.label.set_fontsize(14)
    axis.title.set_fontsize(14)
    plt.show()


# -----------------------------------------------------------------------------

def train_model(X, y, Xy_units, Xy_names, model, regr):
    """Train and evaluate `model` over ShuffleSplit cross-validation folds.

    regr keys used: 'cv_splits', 'test_size', 'seed', 'scaling',
    'param_opt' (its 'n_estimators' entry feeds the deviance plot).
    For every fold, a prediction plot is drawn and MAE/RMSE/R2 are printed;
    tree ensembles additionally get an averaged feature-importance plot,
    gradient boosting a deviance plot, and every model a learning curve.
    """

    model_name = model.__class__.__name__

    sp = ShuffleSplit(
        n_splits = regr['cv_splits'],
        test_size = regr['test_size'],
        train_size = None,
        random_state = regr['seed'])

    RMSE = np.zeros((regr['cv_splits']), dtype = float)
    MAE  = np.zeros((regr['cv_splits']), dtype = float)
    R2   = np.zeros((regr['cv_splits']), dtype = float)

    # feature importances exist only for the tree-ensemble models; the
    # buffer is allocated (and later read) only for them
    has_ft_imp = model_name in ('GradientBoostingRegressor', 'RandomForestRegressor')
    if has_ft_imp:
        Ft_imp = np.zeros((np.shape(X)[1], regr['cv_splits']), dtype = float)

    print(f'\nTraining {model_name} (rs = {regr["seed"]}):')
    for i, (train, test) in enumerate(sp.split(X, y)):

        X_train, X_test = X[train], X[test]
        y_train, y_test = y[train], y[test]

        if regr['scaling']:
            # features standardized on the training fold; the target is
            # normalized by its training-fold maximum (assumes it is != 0)
            scaler = StandardScaler()
            scaler.fit(X_train)
            max_y = np.max(y_train)

            X_train = scaler.transform(X_train)
            X_test = scaler.transform(X_test)

            y_train = y_train/max_y
            y_test = y_test/max_y
        else:
            max_y = 1.0

        model.fit(X_train, y_train)

        # predictions and targets mapped back to original units
        y_fit  = model.predict(X_train)*max_y
        y_pred = model.predict(X_test)*max_y
        y_train = y_train*max_y
        y_test = y_test*max_y

        # sqrt(MSE): the `squared` keyword of mean_squared_error was
        # deprecated in scikit-learn 1.4 and removed in 1.6
        RMSE[i] = np.sqrt(mean_squared_error(y_test, y_pred))
        MAE[i] = mean_absolute_error(y_test, y_pred)
        R2[i]  = r2_score(y_test, y_pred)
        if has_ft_imp:
            # guarded: other regressors have no feature_importances_ and no
            # Ft_imp buffer is allocated for them
            Ft_imp[:,i] = model.feature_importances_

        prediction_plot(y_train, y_fit, y_test, y_pred, i+1, Xy_names, Xy_units, regr['seed'], model_name)
        print(' - %d: MAE = %.2f; RMSE = %.2f; R2 = %.2f' % (i, MAE[i], RMSE[i], R2[i]))

    if has_ft_imp:
        # importances averaged over folds and scaled to percent
        mean_Ft_imp = np.mean(Ft_imp, axis = 1)
        mean_Ft_imp = 100*(mean_Ft_imp/np.sum(mean_Ft_imp))

        Ft_lim = 14
        feature_importance(mean_Ft_imp, Xy_names, Ft_lim, i+1)

    if model_name == 'GradientBoostingRegressor':
        deviance(regr['param_opt']['n_estimators'], model, X_test, y_test, y_pred)

    learning_curve(X, y, Xy_units, Xy_names, model, regr, model_name)

    print('\033[31m\033[1m\nMean  MAE on test set is %.2f (%.2f) %s\033[0m' % (np.mean(MAE),  np.std(MAE),  Xy_units[-1]))
    print('\033[31m\033[1mMean RMSE on test set is %.2f (%.2f) %s\033[0m' % (np.mean(RMSE), np.std(RMSE), Xy_units[-1]))
    print('\033[31m\033[1mMean coefficient of determination: %.2f\033[0m' % np.mean(R2))
    
    
# ----------------------------------------------------------------------------

def tr_vs_ft(X, y, Xy_units, Xy_names, ind, N_points, show_bars = False):
    """Scatter the target against feature `ind`, optionally overlaying red
    bin-average segments computed over `N_points` equal-width bins."""

    N_sm = np.shape(X)[0]

    lo_lim = np.min(X[:,ind])
    hi_lim = np.max(X[:,ind])
    step = (hi_lim - lo_lim)/N_points

    edges = np.linspace(lo_lim, hi_lim, N_points+1)
    bin_mean = np.zeros(N_points)
    bin_std = np.zeros(N_points)

    for k in range(N_points):
        # the first bin includes its left edge; the others are left-exclusive
        if k == 0:
            members = [y[j] for j in range(N_sm) if edges[k] <= X[j,ind] <= edges[k+1]]
        else:
            members = [y[j] for j in range(N_sm) if edges[k] < X[j,ind] <= edges[k+1]]
        if members:
            bin_mean[k] = np.mean(members)
            bin_std[k] = np.std(members)

    fig, axes = plt.subplots(figsize = (4,4), dpi = 500)
    if Xy_units[ind] == 'nan':
        # unitless features are element contents, shown in atomic percent
        plt.xlabel(f'{Xy_names[ind]} (at.%)', size = 22)
    else:
        plt.xlabel(f'{Xy_names[ind]} ({Xy_units[ind]})', size = 22)
    plt.ylabel(f'{Xy_names[-1]} ({Xy_units[-1]})', size = 22)
    axes.scatter(X[:,ind], y, c = 'k', s = 20)
    if show_bars:
        for k in range(N_points):
            axes.plot([edges[k], edges[k+1]], [bin_mean[k], bin_mean[k]], c = 'r', linewidth = 3)

    axes.set_xlim(lo_lim - step/2, hi_lim + step/2)
    axes.set_ylim(0, np.max(y)*1.05)
    plt.xticks(fontsize = 16)
    plt.yticks(fontsize = 16)
    plt.show()


# ----------------------------------------------------------------------------
    
def pcc_matrix(X, y, Xy_names):
    """Draw the Pearson correlation heatmap of all features plus the target."""

    # target appended as the last column so it appears in the matrix too
    stacked = np.concatenate((X, y.reshape(len(y), 1)), axis = 1)
    frame = pd.DataFrame(stacked, columns = Xy_names)
    correlations = frame.corr()

    plt.figure(figsize = (14,14), dpi = 500)
    sns.heatmap(correlations, annot = True, cmap = 'bwr', fmt = "0.2f", cbar_kws = {"shrink": .9})
    plt.xticks(fontsize = 14)
    plt.yticks(fontsize = 14)
    plt.show()