import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import sys
from matplotlib.patches import FancyBboxPatch
from matplotlib.colors import Normalize
import matplotlib.patheffects as path_effects

# Set normalization for contour plot
class MidpointNormalize(Normalize):
    """Piecewise-linear colour normalisation anchored at a chosen midpoint.

    Values are mapped so that ``vmin -> 0``, ``midpoint -> 0.5`` and
    ``vmax -> 1`` with linear interpolation in between, which keeps a
    diverging colormap centred on ``midpoint`` even when the data range is
    asymmetric.

    Args:
        vmin: Lower data limit; if None it is taken from the first data
            the instance is called with.
        vmax: Upper data limit; if None it is taken from the first data
            the instance is called with.
        midpoint: Data value mapped to the centre (0.5) of the colormap.
        clip: Forwarded unchanged to ``Normalize``; this class does not
            use it because ``np.interp`` already limits the output to
            the [0, 1] range.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Return ``value`` mapped onto [0, 1] as a masked array.

        Args:
            value: Scalar or array-like data to normalise.
            clip: Accepted for interface compatibility and ignored.

        Returns:
            A ``numpy.ma.masked_array`` with the shape of ``value``; any
            mask carried by ``value`` is kept on the result.
        """
        # Mirror the base class: fill in missing limits from the data
        # instead of failing inside np.interp on a None anchor.
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)
        anchors = [self.vmin, self.midpoint, self.vmax]
        mapped = np.interp(value, anchors, [0, 0.5, 1])
        # np.interp returns a plain ndarray, so re-attach the input mask
        # (np.ma.nomask for unmasked input) to keep masked cells masked.
        return np.ma.masked_array(mapped, mask=np.ma.getmask(value))
# Colour normalisation centred on zero, used by the contour panels.
norm = MidpointNormalize(midpoint=0)
# Rounded white box drawn behind the panel letters.
props = {
    'boxstyle': 'round,pad=0.5,rounding_size=0.2',
    'facecolor': 'white',
    'alpha': 1.0,
}
e = 1.60217662e-19  # elementary charge

# Set the font size globally
plt.rcParams.update({'font.size': 20})

# Time discretisation of one period of frequency f.
NumberTimeSteps = 10000
ave_step_size = 200
f = 13.56e6
dt = (1.0/NumberTimeSteps/f)
NumberTimeStepsAveraged = NumberTimeSteps // ave_step_size
dt2 = (1.0/NumberTimeStepsAveraged/f)

# Time axes: the coarse (averaged) one and the full-resolution one,
# where sample k sits at k times the respective step.
Time2 = dt2 * np.arange(NumberTimeStepsAveraged)
Time1 = dt * np.arange(NumberTimeSteps)
print(len(Time1))
print(len(Time2))

def _overlay_electron_contours(ax, time_ns, grid, electron_density):
    """Draw labelled black iso-lines of ``electron_density * 1e-16`` on ``ax``."""
    contour_levels = np.linspace(electron_density.min() * 1e-16,
                                 electron_density.max() * 1e-16, 15)
    contour = ax.contour(time_ns, grid, electron_density.T * 1e-16,
                         levels=contour_levels[1:], colors='black', linewidths=0.4)

    # Regular lattice of positions at which the contour labels are placed.
    x_range = np.linspace(10, 65, 4)
    y_range = np.linspace(0.1, 0.8, 4)
    label_positions = [(x, y) for x in x_range for y in y_range]

    labels = ax.clabel(contour, colors='white', inline=True, fontsize=20,
                       fmt='%d', manual=label_positions)
    for label in labels:
        # Black outline keeps the white labels readable on any background.
        label.set_path_effects([path_effects.Stroke(linewidth=1.5, foreground='black'),
                                path_effects.Normal()])


def density_plot2(pack, pack2, titles, x_label, y_label, colors, plot_name, markers,
                  line_styles, o2_percentages_list, pick=1, show_electron_contours=False):
    """Plot a 2x3 figure of charge-density maps and density profiles and save it.

    The top row shows, per column, a filled contour map of
    ``e * (pack[pick][idx] - pack2[pick][idx])`` over time and position.
    The bottom row shows the time average of ``pack[pick][idx]`` together
    with ``pack2[pick][idx]`` at three fixed time steps, which are also
    marked as vertical lines in the map above.

    Relies on the module-level names ``Time1``, ``dt``, ``e``, ``norm`` and
    ``props``.

    Args:
        pack: Sequence of parameter studies; ``pack[pick]`` holds three 2-D
            arrays indexed ``[time step, grid point]`` (labelled as ion
            density in the legend).
        pack2: Same layout as ``pack``; plotted with the label n_e.
        titles: Unused; kept for interface compatibility.
        x_label: Label of the position axis (x-axis of the profiles and
            y-axis of the first contour map).
        y_label: y-axis label of the first profile panel.
        colors: Unused; kept for interface compatibility.
        plot_name: Sequence of output file names; ``plot_name[pick]`` is
            written.
        markers: Marker styles, cycled over the time snapshots.
        line_styles: Line styles, cycled over the vertical snapshot lines.
        o2_percentages_list: Nested sequences of O2 values; their sorted
            unique values become the column titles.
        pick: Index of the parameter study to plot (default 1).
        show_electron_contours: If True, overlay labelled iso-lines of
            ``pack2[pick][idx]`` on each contour map (default False).
    """
    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(20, 12), sharex=False)

    # Time-step indices of the three snapshots and their fixed styling.
    snapshot_steps = [5400, 6200, 7000]
    snapshot_colors = ['red', 'green', 'blue']
    snapshot_styles = ["solid", "dashed", "dotted"]
    top_panel_tags = ["a)", "b)", "c)"]
    bottom_panel_tags = ["d)", "e)", "f)"]
    # Hand-picked y ticks of the profile panel in each column.
    profile_yticks = [
        [0, 3, 6, 9, 12, 15, 18],
        [0, 3, 6, 9, 12, 15],
        [0, 3, 6, 9, 12],
    ]
    unique_o2_percentages = sorted({o2 for o2_list in o2_percentages_list for o2 in o2_list})

    # Everything below depends only on `pick`, not on the column.
    ion_density = pack[pick]
    electron_density = pack2[pick]
    save_plot_name = plot_name[pick]

    grid = np.linspace(0, 1, 201)
    time_ns = Time1 * 1e9
    charge_scale = 1e3  # matches the 10^-3 in the colourbar unit label
    # The mean profile uses the marker of the last snapshot.
    mean_marker = markers[(len(snapshot_steps) - 1) % len(markers)]

    for idx in range(3):  # one column per O2 value
        ax_contour = axes[0, idx]
        ax_profile = axes[1, idx]

        # --- top row: charge-density map ---
        rho = e * (ion_density[idx] - electron_density[idx])
        # NOTE(review): the same module-level `norm` instance is passed to
        # every panel - confirm that a shared colour scale is intended.
        cax = ax_contour.contourf(time_ns, grid, rho.T * charge_scale,
                                  levels=100, cmap='seismic', norm=norm)

        if show_electron_contours:
            _overlay_electron_contours(ax_contour, time_ns, grid, electron_density[idx])

        ax_contour.set_title(f'{unique_o2_percentages[idx]}% O$_2$', fontsize=20)
        # Raw string: '\,' is a mathtext thin space, not a Python escape.
        ax_contour.set_xlabel(r't$\,$[ns]')
        ax_contour.set_xlim(0, 70)
        ax_contour.set_xticks([0, 10, 20, 30, 40, 50, 60, 70])
        fig.colorbar(cax, ax=ax_contour, pad=0.01)
        # Quantity symbol and unit are placed by hand above the colourbar.
        ax_contour.text(0.78, 1.025, r'$\rho$', transform=ax_contour.transAxes,
                        ha='center', fontsize=20)
        ax_contour.text(1.0, 1.02, r'[10$^{-3}$Asm$^{-3}$]', transform=ax_contour.transAxes,
                        ha='center', fontsize=20)

        # Vertical lines marking the snapshot times.
        for i, step in enumerate(snapshot_steps):
            ax_contour.axvline(x=round(1e9 * dt * step, 2), color=snapshot_colors[i],
                               linestyle=line_styles[i % len(line_styles)], linewidth=3)

        # --- bottom row: spatial profiles ---
        ave_ro = np.mean(ion_density[idx], axis=0)
        ax_profile.plot(grid, ave_ro*1e-16, color="black",
                        linewidth=3, linestyle="solid", marker=mean_marker,
                        markevery=20, markersize=5)
        for i, step in enumerate(snapshot_steps):
            ax_profile.plot(grid, electron_density[idx][step, :]*1e-16, color=snapshot_colors[i],
                            linewidth=3, linestyle=snapshot_styles[i],
                            marker=markers[i % len(markers)], markevery=20)

        ax_profile.set_xlabel(x_label)
        ax_profile.set_xlim(0.2, 0.8)
        ax_profile.set_facecolor('white')
        ax_profile.grid(True)
        if idx == 0:
            # Only the left-most column carries y-axis labels.
            ax_contour.set_ylabel(x_label)
            ax_profile.set_ylabel(y_label)
        ax_profile.set_yticks(profile_yticks[idx])

        ax_contour.text(0.05, 0.95, top_panel_tags[idx], transform=ax_contour.transAxes, fontsize=18, va='top', backgroundcolor='1.0', alpha=1.0, bbox=props)
        ax_profile.text(0.92, 0.95, bottom_panel_tags[idx], transform=ax_profile.transAxes, fontsize=18, va='top', backgroundcolor='1.0', alpha=1.0, bbox=props)

    # Figure legend: one entry per snapshot plus one for the mean profile,
    # so handles and labels have the same length.
    handles = [Line2D([0], [0], color=snapshot_colors[i], lw=2, linestyle=snapshot_styles[i])
               for i in range(len(snapshot_steps))]
    handles.append(Line2D([0], [0], color="black", lw=2, linestyle='solid'))
    labels = [
        rf"$n_{{\mathrm{{e}}}}(t={round(1e9 * Time1[step], 1)})\,\mathrm{{ns}}$"
        for step in snapshot_steps
    ] + [
        r"$\overline{n}_{i,\Delta}$"
    ]

    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=6)

    # Leave the top 7 % of the figure free for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    fig.savefig(save_plot_name, bbox_inches='tight')
    plt.close(fig)

def load_data(path):
    """Read the electron and ion density tables found in directory ``path``.

    Seven text files are read with ``np.loadtxt``: the electron density and
    ion densities 1 to 6. Ion species 1, 2 and 4 are summed into one array
    and species 3, 5 and 6 into another (presumably the positive and
    negative ions, respectively - confirm against the simulation setup).

    Args:
        path: Directory containing the ``ReportData*XT.out`` files.

    Returns:
        A tuple ``(pack, n_rel, rho_i)`` where ``pack`` is the list
        ``[electron density, sum of species 1/2/4, sum of species 3/5/6]``,
        ``n_rel`` is the mean of the second sum divided by the mean of the
        first, and ``rho_i`` is the first sum minus the second.
    """
    electron_density = np.loadtxt(f'{path}/ReportDataElectronDensityXT.out')
    ion_density = {
        species: np.loadtxt(f'{path}/ReportDataIonDensity_{species}_XT.out')
        for species in range(1, 7)
    }

    positive_ions = ion_density[1] + ion_density[2] + ion_density[4]
    negative_ions = ion_density[3] + ion_density[5] + ion_density[6]
    net_ion_density = positive_ions - negative_ions
    negative_fraction = np.mean(negative_ions) / np.mean(positive_ions)

    return [electron_density, positive_ions, negative_ions], negative_fraction, net_ion_density

# Load the simulation output for every voltage / O2 case.
# (The 500pp cases are currently not loaded.)
(
    (pack_600_005, nrel_600_005, rho_i_net_600_005),
    (pack_600_01, nrel_600_01, rho_i_net_600_01),
    (pack_600_02, nrel_600_02, rho_i_net_600_02),
    (pack_600_05, nrel_600_05, rho_i_net_600_05),
) = [load_data(f'600pp/{fraction}') for fraction in ('0.05', '0.1', '0.2', '0.5')]

(
    (pack_700_005, nrel_700_005, rho_i_net_700_005),
    (pack_700_01, nrel_700_01, rho_i_net_700_01),
    (pack_700_02, nrel_700_02, rho_i_net_700_02),
    (pack_700_05, nrel_700_05, rho_i_net_700_05),
) = [load_data(f'700pp/{fraction}') for fraction in ('0.05', '0.1', '0.2', '0.5')]

n_rel_600 = [nrel_600_005, nrel_600_01, nrel_600_02, nrel_600_05]
n_rel_700 = [nrel_700_005, nrel_700_01, nrel_700_02, nrel_700_05]

# Styling handed to the plot routine.
colors = ['red', 'green', 'white']  # Electrons, Positive Ions, Negative Ions
# Ensure enough markers and linestyles for O2 percentages
markers = list('os^Dv><p')
line_styles = ['-', '--', '-.', ':', (0, (5, 1)), (0, (3, 1, 1, 1))]

# O2 values shown per voltage (three independent but equal lists).
o2_percentages_500, o2_percentages_600, o2_percentages_700 = (
    [0.05, 0.1, 0.5] for _ in range(3)
)

# Per-voltage selections of the 0.05 / 0.1 / 0.5 cases.
pack_pack_600 = [pack_600_005, pack_600_01, pack_600_05]
pack_pack_700 = [pack_700_005, pack_700_01, pack_700_05]
# First entry of each pack is the electron density returned by load_data.
ele_pack_600 = [case[0] for case in pack_pack_600]
ele_pack_700 = [case[0] for case in pack_pack_700]
ele_pack_pack = [ele_pack_600, ele_pack_700]

pack_600_rho = [rho_i_net_600_005, rho_i_net_600_01, rho_i_net_600_05]
pack_700_rho = [rho_i_net_700_005, rho_i_net_700_01, rho_i_net_700_05]
pack_rho2 = [pack_600_rho, pack_700_rho]

# Prepare data for the group plot
titles = [r'500 V$_{pp}$', r'600 V$_{pp}$', r'700 V$_{pp}$']
o2_percentages_list = [o2_percentages_500, o2_percentages_600, o2_percentages_700]

names_and_locations = [
    'chargedensity_plot_group_snaps600.png',
    'chargedensity_plot_group_snaps700.png',
]

# Create the combined plot with density_plot2
density_plot2(
    pack=pack_rho2,
    pack2=ele_pack_pack,
    titles=titles,
    x_label='x [mm]',
    y_label=r'n [10$^{16}\,$m$^{-3}$]',
    colors=colors,
    plot_name=names_and_locations,
    markers=markers,
    line_styles=line_styles,
    o2_percentages_list=o2_percentages_list,
)
