import numpy as np
from samIO import simulations as sim, perimeters as prm
from netCDF4 import Dataset

# ---------------------------- GLOBAL CONTROL ----------------------------
# True  -> create (or overwrite) each output netCDF file before writing the
#          perimeter values into it.
# False -> reuse a previously built file and write the values into it.
build_netcdf = True
# -------------------------------------------------------------------------

def calcPerimeterValues(times=None, nl_directory='namelists/', namelists=None,
                        masks=None, build=None):
    '''
    Loop through simulations and their output times, and collect the
    MSE and SSE values found on cloud perimeters at every height level.
    The result is one netCDF file per (simulation, mask) pair containing
    these values along with useful domain metrics (height, layer
    thickness, density, reference pressure).

    NOTE: This script makes use of variable-length (VLEN) netCDF arrays.
    Perimeter MSE/SSE is a function of (time, z), where each time-height
    point contains a 1D numpy array of float32 values holding ALL of the
    perimeter values at that level. Compression/least_significant_digit
    cannot be used for variable-length datatypes.

    Masking is done manually: a numeric mask selects QN >= mask (g/kg);
    a string mask selects QN >= 0.01*QVSAT and is labelled 'QS' in the
    output file name. Masks can be changed with the ``masks`` argument.

    Read/write paths are controlled by the namelist files for the SAM
    simulation EGP class type in samIO.simulations.

    If you want more cloud edge values, add them in the per-height loop
    of ``_compute_edge_values``.

    Parameters
    ----------
    times : array-like, optional
        Output times passed to the samIO loaders. Defaults to
        ``np.arange(2880, 172801, 2880)``, matching the 3D output files.
    nl_directory : str, optional
        Prefix joined directly to each namelist name (keep the trailing
        slash). Defaults to ``'namelists/'``.
    namelists : sequence of str, optional
        Namelist file names, one per simulation. Defaults to
        nlRCE295FR.json, nlRCE300FR.json and nlRCE305FR.json.
    masks : sequence of float or str, optional
        Cloud masks. Floats are QN thresholds in g/kg; any string selects
        the QN >= 0.01*QVSAT mask. Defaults to ``[0.01]``.
    build : bool, optional
        Whether to (re)create the netCDF skeleton before writing. Defaults
        to the module-level ``build_netcdf`` flag, read at call time.

    Returns
    -------
    None. Results are written to netCDF files.
    '''
    if times is None:
        times = np.arange(2880, 172801, 2880)  # Matches 3D output files
    if namelists is None:
        namelists = ['nlRCE295FR.json', 'nlRCE300FR.json', 'nlRCE305FR.json']
    if masks is None:
        masks = [0.01]  # in g/kg
    if build is None:
        build = build_netcdf

    nt = len(times)

    for nl in namelists:

        # Initialize simulation class and bring in relevant data
        print('Namelist: ', nl)
        sam = sim.EGP(f'{nl_directory}{nl}')
        sam.get_coordinates(times[0])
        sam.calc_dz()
        sam.rho = sam.load_stat('RHO')[0]  # Same for all time
        sam.pres = sam.load_stat('p')
        nz = len(sam.coordinates[2])

        for mask in masks:

            m_str = _mask_label(mask)
            print(f'Mask: {m_str}')
            # Path to write netCDF
            nc_path = f'{sam.dSAM}OUT_CALC/EDGE/{sam.fname}_EDGE_M{m_str}.nc'

            if build:
                _build_netcdf_skeleton(nc_path, sam, nt, nz)

            edge_mse_tz, edge_sse_tz, day_time = _compute_edge_values(
                sam, times, mask, nz)
            _write_edge_values(nc_path, edge_mse_tz, edge_sse_tz, day_time)

            print(f'Mask {m_str} Completed.')


def _mask_label(mask):
    '''Return the file-name label for a mask: 'QS' for string masks,
    otherwise the numeric threshold formatted to 3 decimals.'''
    if isinstance(mask, str):
        return 'QS'
    return f'{mask:.3f}'


def _build_netcdf_skeleton(nc_path, sam, nt, nz):
    '''Create (overwriting) the netCDF file at ``nc_path`` with the time/z
    dimensions, the z-profile metadata variables and the 'perimeter'
    VLEN type used later for the ragged edge values.'''
    with Dataset(nc_path, mode='w', format='NETCDF4') as ncf:

        # Dimension Building
        ncf.createDimension('time', nt)
        ncf.createDimension('z', nz)

        # Coordinate Building (time values are filled in when writing)
        time_var = ncf.createVariable('time', np.float32, ('time',))
        time_var.units = 'days'
        time_var.long_name = 'time'

        # Height-only profiles: (name, units, long_name, values)
        profiles = (
            ('z', 'meters (m)', 'Height', sam.coordinates[2]),
            ('dz', 'meters (m)', 'Layer Thickness', sam.dz),
            ('RHO', 'kg m^-3', 'Air Density', sam.rho),
            ('p', 'mb', 'Reference Pressure', sam.pres),
        )
        for name, units, long_name, values in profiles:
            var = ncf.createVariable(name, np.float32, ('z',))
            var.units = units
            var.long_name = long_name
            var[:] = values

        # This is the variable length datatype in netCDF
        ncf.createVLType(np.float32, 'perimeter')


def _compute_edge_values(sam, times, mask, nz):
    '''Return (edge_mse_tz, edge_sse_tz, day_time) for one mask.

    edge_*_tz are (nt, nz) object arrays whose cells hold float32 arrays
    of perimeter values from prm.get_perimeter_edge_values; day_time holds
    ``sam.time`` for each processed time. Loaded fields are stored on
    ``sam`` (SSE, MSE, MASK and, for string masks, QVSAT) as a side effect.
    '''
    nt = len(times)
    edge_sse_tz = np.empty((nt, nz), object)
    edge_mse_tz = np.empty((nt, nz), object)
    day_time = np.zeros(nt, dtype=np.float32)

    for i, t in enumerate(times):

        # Bring in time-specific data
        print(f'Working on time: {i+1} of {nt}', end='\r')
        sam.get_coordinates(t)
        day_time[i] = sam.time  # time keeping
        sam.SSE = sam.load_3D('SSE', t)
        sam.MSE = sam.load_3D('MSE', t)

        # Set Mask
        if isinstance(mask, str):
            sam.QVSAT = sam.load_3D('QVSAT', t)
            sam.MASK = (sam.load_3D('QN', t) >= 0.01*sam.QVSAT)
        else:
            sam.MASK = (sam.load_3D('QN', t) >= mask)

        # Go through each height (last axis) and collect perimeter values
        for k in range(nz):
            mask_layer = sam.MASK[:, :, k]
            edge_mse_tz[i, k] = np.array(
                prm.get_perimeter_edge_values(mask_layer, sam.MSE[:, :, k]),
                dtype=np.float32)
            edge_sse_tz[i, k] = np.array(
                prm.get_perimeter_edge_values(mask_layer, sam.SSE[:, :, k]),
                dtype=np.float32)

    # Terminate the carriage-return progress line so later prints start on
    # a fresh line instead of overwriting it.
    print()
    return edge_mse_tz, edge_sse_tz, day_time


def _get_or_create_edge_var(ncf, name, long_name):
    '''Return the (time, z) VLEN variable ``name`` from ``ncf``, creating
    it with the 'perimeter' type if it does not exist yet.'''
    if name in ncf.variables:
        return ncf.variables[name]
    var = ncf.createVariable(name, ncf.vltypes['perimeter'], ('time', 'z'))
    var.units = 'K'
    var.long_name = long_name
    return var


def _write_edge_values(nc_path, edge_mse_tz, edge_sse_tz, day_time):
    '''Write the ragged edge values and the time axis into an existing
    netCDF file, creating the edge variables on first write.'''
    with Dataset(nc_path, mode='r+') as ncf:
        nc_mse = _get_or_create_edge_var(ncf, 'edge_mse', 'MSE at cloud edge')
        nc_sse = _get_or_create_edge_var(ncf, 'edge_sse', 'SSE at cloud edge')
        nc_mse[:, :] = edge_mse_tz
        nc_sse[:, :] = edge_sse_tz
        ncf.variables['time'][:] = day_time

if __name__ == '__main__':
    # Script entry point: run the full perimeter sweep with default settings.
    calcPerimeterValues()