import numpy as np
from samIO import simulations as sim, perimeters as prm
from netCDF4 import Dataset
from scipy.stats import mode

def weighted_avg_and_std(values, weights):
    """
    Compute the weighted mean and weighted standard deviation of values.

    The weights are effectively normalized to sum to 1, so they must
    not all be 0. Passing weights=None gives every value equal weight.

    values, weights -- NumPy ndarrays with the same shape.

    Returns a tuple (mean, standard deviation).
    """
    mean = np.average(values, weights=weights)
    # Weighted mean of the squared deviations about the weighted mean
    squared_dev = np.square(values - mean)
    spread = np.sqrt(np.average(squared_dev, weights=weights))
    return (mean, spread)

def _load_level_weights(sam, m_str):
    """
    Read the per-level averaging weights from a simulation's EDGE netCDF file.

    sam   -- simulation object; its dSAM, fname and nz attributes are used.
    m_str -- mask label that appears in the EDGE file name.

    Returns a float32 array of length sam.nz filled from the file's RHO variable.
    """
    nc_path = f'{sam.dSAM}OUT_EDGE/{sam.fname}_EDGE_M{m_str}.nc'
    level_weights = np.empty((sam.nz,), np.float32)
    # The context manager closes the file even if a read fails part-way
    with Dataset(nc_path, 'r') as edges:
        rho = edges['RHO']
        for k in range(sam.nz):
            # Only RHO is used as the weight; no dz factor is applied.
            # NOTE(review): presumably intentional so the weights line up with
            # the 2D histograms -- confirm before adding a dz factor.
            level_weights[k] = rho[k]
    return level_weights


def _cloudy_level_bounds(cloud_mask, nz, npoints, min_fraction=0.001):
    """
    Find the range of model levels that contain enough cloud.

    cloud_mask   -- 3D 0/1 array; the first two axes are summed per level.
    nz           -- number of levels to inspect.
    npoints      -- number of horizontal points used to turn counts into fractions.
    min_fraction -- a level counts as cloudy when its cloud fraction is >= this.

    Returns (lowest cloudy level, highest cloudy level + 1), or None when no
    level reaches min_fraction. Levels between the two bounds are included in
    the range even if they are individually below the threshold.
    """
    cloudy_pct = np.sum(cloud_mask, axis=(0, 1)) / npoints
    cloudy_levels = np.flatnonzero(cloudy_pct[:nz] >= min_fraction)
    if cloudy_levels.size == 0:
        return None
    return int(cloudy_levels[0]), int(cloudy_levels[-1]) + 1


def calcDomainStats(namelists=None, masks=None, times=None, calc_weights=True,
                    file_name_append='3sims_1mask_weighted_fullrad'):
    """
    Compute domain mean/std of MSE and SSE over the cloudy height range and save them.

    For every simulation, cloud mask and output time, the 3D MSE and SSE fields
    are restricted to the levels between the lowest and highest "cloudy" level
    (cloud fraction >= 0.1%). All horizontal points on those levels enter the
    statistics; the mask only selects the height range. The same statistics are
    then computed over all times pooled together.

    namelists        -- namelist file names, looked up under 'namelists/'.
                        Defaults to the three RCE295/300/305 FR namelists.
    masks            -- cloud mask definitions. A number is a QN threshold
                        (QN >= mask); any string selects QN >= 0.01*QVSAT and is
                        labelled 'QS' in file names. Defaults to [0.01].
    times            -- output times passed to sam.load_3D.
                        Defaults to np.arange(2880, 172801, 2880).
    calc_weights     -- True: weight each level by RHO from the EDGE file;
                        False: unweighted statistics.
    file_name_append -- suffix of the output .npz file name.

    Writes '{sam.dWrite}RCE_stats_domain_{file_name_append}.npz' (dWrite taken
    from the last simulation) containing:
      domain_stats     -- (sim, mask, time, 4): avgMSE, avgSSE, stdMSE, stdSSE
      domain_stats_all -- (sim, mask, 4): same quantities pooled over all times
      sd               -- (sim, mask, time, 2): <h* - h>, <h*> - <h>
      sda              -- (sim, mask, 2): same quantities pooled over all times

    Raises ValueError if namelists or times is empty, if no level is cloudy
    enough at some time, or if the weights do not match the data size.
    """

###############################
######### PARAMETERS ##########
###############################
    if namelists is None:
        namelists = ['nlRCE295FR.json','nlRCE300FR.json','nlRCE305FR.json']
    if masks is None:
        masks = [0.01] # What QN threshold should be used for the cloud mask
    if times is None:
        times = np.arange(2880,172801, 2880) # Match output times of 3D files
###############################
###############################
###############################

    nt = len(times)
    nsim = len(namelists)
    nmask = len(masks)
    if nsim == 0:
        raise ValueError('namelists must contain at least one simulation')
    if nt == 0:
        raise ValueError('times must contain at least one output time')

    # Initialize arrays to store averages and standard deviations from simulations
    # 3D Domain Average: (sim, mask, time, (avgMSE, avgSSE, stdMSE, stdSSE))
    domain_stats = np.zeros((nsim, nmask, nt, 4), dtype=np.float32)
    # 3D Domain Average across time: (sim, mask, (avgMSE, avgSSE, stdMSE, stdSSE))
    domain_stats_all = np.zeros((nsim, nmask, 4), dtype=np.float32)

    # Array for domain average saturation deficit
    # Last index: 0 - <h* - h>, 1 - <h*> - <h>
    sat_def = np.zeros((nsim, nmask, nt, 2), dtype=np.float32)
    # The all-times array has no time axis. With a time axis, the scalar writes
    # below would be broadcast into time slots 0 and 1 instead of filling the
    # two saturation-deficit entries.
    sat_def_all = np.zeros((nsim, nmask, 2), dtype=np.float32)

    for s, nl in enumerate(namelists):

        print(f'{nl}')
        # Initiate simulation instance
        sam = sim.EGP(f'namelists/{nl}')
        npoints = sam.nx * sam.ny

        # Calculate statistics depending on cloud mask
        for q, mask in enumerate(masks):
            if isinstance(mask, str):
                m_str = 'QS'
            else:
                m_str = f'{mask:.3f}'
            print(f'Mask: {m_str}')

            # Determine weights for each height (read from the EDGE netCDF file)
            level_weights = _load_level_weights(sam, m_str) if calc_weights else None

            # Meat of the calculation. Here the weighted mean/std is calculated
            # for each 3D snapshot of the domain
            per_time_sse = []      # SSE samples for all times
            per_time_mse = []      # MSE samples for all times
            per_time_weights = []  # Matching weights for all times (weighted case only)
            for t, time in enumerate(times):
                print(f'Time: {t+1} of {nt}', end='\r')

                # Load in data. The fields are stored on the simulation object.
                # NOTE(review): other samIO code may read these attributes --
                # confirm before turning them into plain locals.
                sam.SSE = sam.load_3D('SSE', time)
                sam.MSE = sam.load_3D('MSE', time)
                # Determine the appropriate cloud mask
                if isinstance(mask, str):
                    sam.QVSAT = sam.load_3D('QVSAT', time)
                    sam.MASK = (sam.load_3D('QN', time) >= 0.01*sam.QVSAT).astype(np.byte)
                else:
                    sam.MASK = (sam.load_3D('QN', time) >= mask).astype(np.byte)

                # Which heights are cloudy enough?
                # Layers with less than 0.1% cloud do not count as cloudy.
                bounds = _cloudy_level_bounds(sam.MASK, sam.nz, npoints)
                if bounds is None:
                    raise ValueError(
                        f'No level has at least 0.1% cloud cover for {nl}, '
                        f'mask {m_str}, time {time}'
                    )
                cloud_minz, cloud_maxz = bounds

                # Bring in SSE/MSE for appropriate cloudy heights.
                # Fortran-order flattening puts all horizontal points of one
                # level next to each other, level after level.
                sse_t = sam.SSE[:,:,cloud_minz:cloud_maxz].flatten('F')
                mse_t = sam.MSE[:,:,cloud_minz:cloud_maxz].flatten('F')

                # Create the weights: each level's weight repeated once per
                # horizontal point, in the same order as the flattened data
                if calc_weights:
                    weight_t = np.repeat(level_weights[cloud_minz:cloud_maxz], npoints)
                    # Ensure that the domain range we are averaging over has
                    # the same size as the weights
                    if len(sse_t) != len(weight_t):
                        raise ValueError(
                            f'Weights ({len(weight_t)}) and data ({len(sse_t)}) '
                            f'differ in size for {nl}, mask {m_str}, time {time}'
                        )
                    per_time_weights.append(weight_t)
                else:
                    weight_t = None

                # Calculate the mean and standard deviation of SSE and MSE
                # Place appropriately into the statistics array
                sse_bar, sse_std = weighted_avg_and_std(sse_t, weight_t)
                mse_bar, mse_std = weighted_avg_and_std(mse_t, weight_t)
                domain_stats[s,q,t,0] = mse_bar
                domain_stats[s,q,t,1] = sse_bar
                domain_stats[s,q,t,2] = mse_std
                domain_stats[s,q,t,3] = sse_std

                # Calculate the domain average saturation deficit
                sat_def[s,q,t,0] = np.average(sse_t - mse_t, weights=weight_t) # <h* - h>
                sat_def[s,q,t,1] = sse_bar - mse_bar # <h*> - <h>

                per_time_sse.append(sse_t)
                per_time_mse.append(mse_t)

            # Now calculate the average of the domain across all timesteps
            print('Calculating across time...')
            sse_all = np.concatenate(per_time_sse)
            mse_all = np.concatenate(per_time_mse)
            # Without weights there is nothing to concatenate; None requests
            # the unweighted average
            weights_all = np.concatenate(per_time_weights) if calc_weights else None

            sse_all_bar, sse_all_std = weighted_avg_and_std(sse_all, weights_all)
            mse_all_bar, mse_all_std = weighted_avg_and_std(mse_all, weights_all)
            domain_stats_all[s,q,0] = mse_all_bar
            domain_stats_all[s,q,1] = sse_all_bar
            domain_stats_all[s,q,2] = mse_all_std
            domain_stats_all[s,q,3] = sse_all_std

            # Calculate the domain average saturation deficit across all times
            sat_def_all[s,q,0] = np.average(sse_all - mse_all, weights=weights_all) # <h* - h>
            sat_def_all[s,q,1] = sse_all_bar - mse_all_bar # <h*> - <h>


    # The output directory comes from the last simulation processed
    np.savez(f'{sam.dWrite}RCE_stats_domain_{file_name_append}.npz', domain_stats=domain_stats, domain_stats_all=domain_stats_all, sd=sat_def, sda=sat_def_all)
    

# Run the domain statistics calculation when this file is executed as a script
if __name__ == "__main__":
    calcDomainStats()
