# -*- coding: utf-8 -*-
"""
Created on Thu Jan 26 16:24:24 2023

@author: Yutaro
"""
import glob
import os
from enum import Enum

import numpy
import scipy
import scipy.optimize
import scipy.signal

class SwitchingDirection(Enum):
    """Which side of a resistance threshold counts as the target state.

    POSITIVE: a resistance value above the threshold is counted
    (switched / initialized).  NEGATIVE: a value below the threshold is
    counted.
    """

    POSITIVE = 1  # compare with ">" against the threshold
    NEGATIVE = 2  # compare with "<" against the threshold
    
def vosc_fit(v_set, slope):
    """Linear model through the origin: return ``slope * v_set``.

    Serves as the model function passed to ``scipy.optimize.curve_fit`` in
    ``PulseProperty.load``; *v_set* may be a scalar or a numpy array.
    """
    scaled = slope * v_set
    return scaled
    
class PvsV:
    """Pulse-voltage / Hall-resistance data of one switching measurement.

    Measured rows are grouped into "sets" of ``pulse_number`` consecutive
    rows of the data file.  ``v_pulse``, ``r_hall`` and ``r_after_init`` are
    lists holding one 1-D numpy array per set, index-aligned with each other.
    """

    def __init__(self):
        self.i_dc = 100*1e-6  # Ampere (load() converts DC_Current(uA) to A)
        self.pulse_width = 0.1  # ns
        self.pulse_number = 100  # pulses (data rows) per set
        # Field components.  load() stores the Hx/Hy/Hz(mT) header values
        # without conversion, i.e. in mT.
        self.hx = 0
        self.hy = 0
        self.hz = 0
        # NOTE(review): load() stores the I_init(uA) header value without
        # converting it to Ampere (unlike i_dc) -- confirm the intended unit.
        self.i_init = 0
        self.v_pulse = []       # per set: pulse voltages
        self.r_hall = []        # per set: Hall resistance (ohm)
        self.r_after_init = []  # per set: resistance after initialization (ohm)

    @staticmethod
    def _header_value(text, key, required=False):
        """Return the text after the first *key* up to the end of that line.

        Returns None when *key* does not occur, unless *required* is true,
        in which case ValueError is raised.
        """
        _, found, rest = text.partition(key)
        if not found:
            if required:
                raise ValueError(f"header {key.strip()!r} not found")
            return None
        return rest.split("\n", 1)[0]

    @staticmethod
    def _parse_row(row):
        """Parse the first three tab-separated fields of a data row as floats."""
        cells = row.split("\t")
        return float(cells[0]), float(cells[1]), float(cells[2])

    def load(self, file_path):
        """Load parameters and per-set data from a measurement text file.

        Required header keys: ``DC_Current(uA)<TAB>`` (stored in A),
        ``PulseWidth(ns)`` and ``PulseNumber``.  Optional keys ``Hx(mT)``,
        ``Hy(mT)``, ``Hz(mT)`` and ``I_init(uA)`` keep their previous value
        when absent.  Data rows follow the line ending in
        ``R_after_init(ohm)``; each row is ``v_pulse<TAB>r_hall<TAB>r_after_init``
        and every ``pulse_number`` consecutive rows form one set.  Blank lines
        are ignored and a trailing incomplete set is dropped.

        Raises:
            ValueError: a required header or the data header is missing, or
                PulseNumber is not positive.
        """
        with open(file_path, 'r') as result_file:
            full_text_data = result_file.read()

        # float()/int() tolerate the whitespace (tab) following each key.
        self.i_dc = float(self._header_value(
            full_text_data, "DC_Current(uA)\t", required=True)) * 1e-6
        self.pulse_width = float(self._header_value(
            full_text_data, "PulseWidth(ns)", required=True))
        self.pulse_number = int(self._header_value(
            full_text_data, "PulseNumber", required=True))
        if self.pulse_number <= 0:
            raise ValueError(
                f"PulseNumber must be positive, got {self.pulse_number}")

        optional_keys = (("Hx(mT)", "hx"), ("Hy(mT)", "hy"),
                         ("Hz(mT)", "hz"), ("I_init(uA)", "i_init"))
        for key, attr in optional_keys:
            value = self._header_value(full_text_data, key)
            if value is not None:
                setattr(self, attr, float(value))

        data_marker = "R_after_init(ohm)\n"
        _, found, body = full_text_data.partition(data_marker)
        if not found:
            raise ValueError(
                f"data header ending in {data_marker.strip()!r} not found "
                f"in {file_path}")
        rows = [line for line in body.split("\n") if line.strip()]

        n = self.pulse_number
        v_pulse, r_hall, r_after_init = [], [], []
        # Only complete sets are kept: a trailing partial set (e.g. from an
        # aborted run) is dropped instead of indexing past the data.
        for set_index in range(len(rows) // n):
            chunk = rows[set_index * n:(set_index + 1) * n]
            table = numpy.array([self._parse_row(row) for row in chunk],
                                dtype=float)
            v_pulse.append(table[:, 0].copy())
            r_hall.append(table[:, 1].copy())
            r_after_init.append(table[:, 2].copy())
        self.v_pulse = v_pulse
        self.r_hall = r_hall
        self.r_after_init = r_after_init

    def _filtered_set(self, set_index, r_threshold, direction):
        """Return (v_pulse, r_hall, r_after_init) of one set, initialized pulses only.

        A pulse is kept when its r_after_init is above *r_threshold* for
        SwitchingDirection.POSITIVE or below it for NEGATIVE; any other
        *direction* keeps nothing.
        """
        r_init = numpy.asarray(self.r_after_init[set_index], dtype=float)
        if direction == SwitchingDirection.POSITIVE:
            keep = r_init > r_threshold
        elif direction == SwitchingDirection.NEGATIVE:
            keep = r_init < r_threshold
        else:
            keep = numpy.zeros(r_init.shape, dtype=bool)
        v = numpy.asarray(self.v_pulse[set_index], dtype=float)
        r_h = numpy.asarray(self.r_hall[set_index], dtype=float)
        return v[keep], r_h[keep], r_init[keep]

    def check_initialize(self, r_threshold, direction, r_ahe):
        """Drop uninitialized pulses and re-center every set on +/- r_ahe.

        Pulses are filtered as in delete_uninitialized.  Then r_hall and
        r_after_init of each set are shifted by the same offset so that the
        median of the kept r_after_init becomes +r_ahe (POSITIVE) or
        -r_ahe (NEGATIVE).  The per-set list entries are replaced in place.
        """
        for set_index in range(len(self.v_pulse)):
            v, r_h, r_init = self._filtered_set(set_index, r_threshold,
                                                direction)
            # A non-empty result implies direction is POSITIVE or NEGATIVE;
            # skipping empty sets avoids the median of an empty array.
            if r_init.size > 0:
                med = numpy.median(r_init)
                target = (r_ahe if direction == SwitchingDirection.POSITIVE
                          else -r_ahe)
                r_init = r_init - med + target
                r_h = r_h - med + target
            self.v_pulse[set_index] = v
            self.r_hall[set_index] = r_h
            self.r_after_init[set_index] = r_init

    def save_hall_resistance(self, path):
        """Write every (v_pulse, r_hall) pair of all sets as tab-separated lines."""
        with open(path, mode='w') as file:
            for v_set, r_set in zip(self.v_pulse, self.r_hall):
                for v, r in zip(v_set, r_set):
                    file.write(str(v) + '\t' + str(r) + '\n')

    def switching_probability_binary(self, r_threshold, switching_direction):
        """Return (probability, voltage) arrays with one entry per set.

        voltage is the first v_pulse of each set.  probability is the fraction
        of r_hall values above (POSITIVE) or below (NEGATIVE) *r_threshold*;
        it stays 0 for any other direction.  Sets left empty (e.g. by
        delete_uninitialized) yield NaN instead of raising.
        """
        n_sets = len(self.v_pulse)
        voltage = numpy.zeros(n_sets)
        probability = numpy.zeros(n_sets)
        for set_index in range(n_sets):
            v_set = self.v_pulse[set_index]
            r_set = numpy.asarray(self.r_hall[set_index])
            voltage[set_index] = v_set[0] if len(v_set) else numpy.nan
            if len(r_set) == 0:
                probability[set_index] = numpy.nan
                continue
            if switching_direction == SwitchingDirection.POSITIVE:
                switched = numpy.count_nonzero(r_set > r_threshold)
            elif switching_direction == SwitchingDirection.NEGATIVE:
                switched = numpy.count_nonzero(r_set < r_threshold)
            else:
                continue
            probability[set_index] = switched / len(r_set)
        return probability, voltage

    def binarize_auto(self, bins=200):
        """Shift resistances so the midpoint between the two r_hall states is 0.

        All r_hall values are histogrammed into *bins* bins and two peaks at
        least bins/2 bins apart are searched.  On success, the mean of the two
        peak bin centres is subtracted from every r_hall and r_after_init
        array (the per-set lists are kept as lists) and True is returned.
        Returns False when there is no data or not exactly two peaks.
        """
        if len(self.r_hall) == 0:
            return False

        r_hall = numpy.concatenate([numpy.ravel(r) for r in self.r_hall])
        counts, edges = numpy.histogram(r_hall, bins=bins)
        peaks = scipy.signal.find_peaks(counts, distance=bins / 2.0)[0]
        if len(peaks) != 2:
            print("Failed to binarize...")
            return False
        # Use bin centres; left edges would bias the midpoint by half a bin.
        centers = 0.5 * (edges[:-1] + edges[1:])
        r_mid = (centers[peaks[1]] + centers[peaks[0]]) / 2.0

        self.r_hall = [numpy.asarray(r) - r_mid for r in self.r_hall]
        self.r_after_init = [numpy.asarray(r) - r_mid
                             for r in self.r_after_init]
        return True

    def delete_uninitialized(self, r_threshold, direction):
        """Keep only pulses whose r_after_init passed initialization.

        A pulse is kept when r_after_init > r_threshold for POSITIVE or
        r_after_init < r_threshold for NEGATIVE.  v_pulse, r_hall and
        r_after_init are replaced by new lists of filtered per-set arrays.
        """
        filtered = [self._filtered_set(set_index, r_threshold, direction)
                    for set_index in range(len(self.v_pulse))]
        self.v_pulse = [entry[0] for entry in filtered]
        self.r_hall = [entry[1] for entry in filtered]
        self.r_after_init = [entry[2] for entry in filtered]
    
# Non-member functions
def save_p_v(voltage, probability, file_path):
    """Write a voltage/probability table, one ``voltage<TAB>probability`` line per point.

    Values are formatted with ``str()``.  The file is closed even if a write
    fails.

    Raises:
        ValueError: if *voltage* and *probability* have different lengths.
    """
    if len(voltage) != len(probability):
        raise ValueError(
            f"voltage ({len(voltage)}) and probability ({len(probability)}) "
            "must have the same length")
    with open(file_path, 'w') as save_file:
        for v, p in zip(voltage, probability):
            save_file.write(str(v) + '\t' + str(p) + '\n')

class PulseProperty:
    """Pulse-width and oscillator-voltage calibration from a folder of files.

    After load_full(), the three arrays hold one entry per ``*.txt`` file
    (files processed in sorted path order): the nominal pulse width parsed
    from the file name, the mean measured pulse width, and the fitted
    Vosc/Vset slope.
    """

    def __init__(self):
        self.pulse_width_nominal = numpy.zeros(0)  # from file name: ps value * 1e-3 (ns)
        self.pulse_width = numpy.zeros(0)          # mean of data column 6 per file
        self.vosc_slope = numpy.zeros(0)           # Vosc/Vset

    @staticmethod
    def _nominal_tau_from_path(file_path):
        """Return the pulse width encoded as ``..._<value>ps.txt`` in the file name, times 1e-3.

        Only the file name (not the directory part) is inspected, so
        underscores in folder names do not break the parsing.
        """
        file_name = os.path.basename(file_path)
        return float(file_name.split('_')[-1].split("ps.txt")[0]) * 1e-3

    def load(self, file_path):
        """Read one calibration file and return (nominal_tau, measured_tau, slope).

        The first line is skipped as a header; every other non-empty line is
        tab-separated.  Column 0 is v_set, column 2 is v_med and column 6 is
        tau_med.  slope is the least-squares fit of ``v_med = slope * v_set``
        (model ``vosc_fit``); measured_tau is the mean of tau_med.

        Raises:
            ValueError: if the file name has no parsable ``_<value>ps.txt``
                suffix or the file contains no data rows.
        """
        nominal_tau = self._nominal_tau_from_path(file_path)

        with open(file_path, 'r') as result_file:
            lines = result_file.read().split("\n")
        # Skip the header line and any blank lines (e.g. a trailing newline).
        rows = [line.split("\t") for line in lines[1:] if line.strip()]
        if not rows:
            raise ValueError(f"no data rows in {file_path}")

        v_set = numpy.array([float(row[0]) for row in rows])
        v_med = numpy.array([float(row[2]) for row in rows])
        tau_med = numpy.array([float(row[6]) for row in rows])

        params, _ = scipy.optimize.curve_fit(vosc_fit, v_set, v_med)
        slope = params[0]
        measured_tau = numpy.mean(tau_med)

        return nominal_tau, measured_tau, slope

    def load_full(self, folder_path):
        """Load every ``*.txt`` file directly inside *folder_path*.

        Files are processed in sorted order so the resulting arrays are
        deterministic.  The number of files found is printed.  The arrays are
        only replaced once every file has been loaded successfully.
        """
        pattern = os.path.join(glob.escape(folder_path), '*.txt')
        file_paths = sorted(glob.glob(pattern))
        print(len(file_paths))

        results = [self.load(path) for path in file_paths]
        self.pulse_width_nominal = numpy.array([r[0] for r in results],
                                               dtype=float)
        self.pulse_width = numpy.array([r[1] for r in results], dtype=float)
        self.vosc_slope = numpy.array([r[2] for r in results], dtype=float)