import json
import numpy as np
from netCDF4 import Dataset
#from xarray import open_dataset
from pickle import load
from functools import reduce
from numba import njit


class EGP:
    '''
    Reader and analysis helper for SAM simulation output.

    Configuration comes from a JSON namelist that supplies the output
    directory layout (dSAM, dWrite, fname, ntimedig and optionally stat),
    file-layout flags (xyz_rev, var_sep) and grid information
    (dx, dy, dz, nx, ny, nz). Paths are built by plain string concatenation,
    so dSAM is expected to end with a path separator.
    '''

    def __init__(self, namelist, analyze_shells=False):
        '''
        namelist - path to the JSON namelist file
        analyze_shells - if True, also read conserved_variable, mask and
                         shell_codes from the namelist

        Raises KeyError if a required namelist key is missing.
        '''

        with open(namelist) as jfile:
            data = json.load(jfile)

        # Directory/filesystem
        self.dSAM = data['dSAM']
        self.dWrite = data['dWrite']
        self.fname = data['fname']
        self.ntimedig = data['ntimedig']

        # Statistics file: named by the optional 'stat' key, otherwise by
        # fname. stat_name is only defined when 'stat' is present.
        if 'stat' in data:
            self.stat_name = data['stat']
            self.stat = f'{self.dSAM}OUT_STAT/{self.stat_name}.nc'
        else:
            self.stat = f'{self.dSAM}OUT_STAT/{self.fname}.nc'

        # Layout flags. They may be written either as JSON booleans or as the
        # string "true"; comparing only against 'true' used to turn a JSON
        # `true` into False.
        self.xyz_rev = self._as_bool(data['xyz_rev'])
        self.var_sep = self._as_bool(data['var_sep'])

        # Analysis variables and scope (time)
        if analyze_shells:
            self.conserved = data['conserved_variable']
            self.mask = data['mask']
            self.shell_codes = data['shell_codes']

        # Simulation information
        self.dx = data['dx']
        self.dy = data['dy']
        self.dz = data['dz']
        self.nx = data['nx']
        self.ny = data['ny']
        self.nz = data['nz']

    @staticmethod
    def _as_bool(value):
        '''Interpret a namelist flag: JSON true or the string "true" (any case) -> True.'''
        if isinstance(value, str):
            return value.lower() == 'true'
        return value is True

    def _time_string(self, time):
        '''Format an integer time step zero-padded to ntimedig digits.'''
        return f'{time:0{self.ntimedig}d}'

    def _path_3D(self, var, time0):
        '''
        Path of the 3D output file holding `var` at formatted time `time0`,
        following SAM's layout: one directory per variable when var_sep is
        set, otherwise a single OUT_3D directory holding all variables.
        '''
        if self.var_sep:
            return f'{self.dSAM}OUT_3D.{var}/{self.fname}_{var}_{time0}.nc'
        return f'{self.dSAM}OUT_3D/{self.fname}_{time0}.nc'

    def _drop_time(self, arr, axes):
        '''
        Remove the leading time index from a multi-dimensional variable and,
        when xyz_rev is set, reorder its remaining axes with `axes`.
        Variables with at most one dimension (e.g. coordinates) are returned
        unchanged.
        '''
        if len(arr.shape) > 1:
            arr = arr[0]
            if self.xyz_rev:
                arr = arr.transpose(axes)
        return arr

    def load_stat(self, var):
        '''
        Loads the statistics file for the simulation.

        Pass either a single variable name as a string (returns one array), or
        an iterable of names (returns a list of arrays in the same order).
        '''
        with Dataset(self.stat, 'r') as stat:
            stat.set_auto_mask(False)
            if isinstance(var, str):
                return stat[var][:]
            return [stat[v][:] for v in var]

    def load_2D(self, var, time, simtime=False):
        '''
        Loads variables from the 2D netCDF file for the given time step.

        var can be a single variable name (returns one array) or an iterable
        of names (returns a list). Non-2D variables such as x, y or time can
        also be requested; they are returned unchanged.

        Multi-dimensional variables have their first (time) index removed and,
        if xyz_rev is set, their two remaining axes swapped so the result is
        (x, y). This assumes the file stores (t, y, x) — may need to change
        for other data.

        simtime is currently unused and kept for backward compatibility.
        '''
        path = f'{self.dSAM}OUT_2D/{self.fname}_{self._time_string(time)}.nc'

        with Dataset(path, 'r') as dat:
            dat.set_auto_mask(False)
            if isinstance(var, str):
                return self._drop_time(dat[var][:], (1, 0))
            return [self._drop_time(dat[v][:], (1, 0)) for v in var]

    def _read_3D(self, path, names, mask):
        '''Read `names` from one 3D file, dropping time / reordering axes.'''
        with Dataset(path, 'r') as dat:
            if not mask:
                dat.set_auto_mask(False)
            return [self._drop_time(dat[n][:], (2, 1, 0)) for n in names]

    def load_3D(self, var, time, mask=False):
        '''
        Loads variables from the 3D netCDF output for the given time step.

        var can be a single variable name (returns one array) or an iterable
        of names (returns a list). With var_sep set, each variable is read
        from its own OUT_3D.<var> file; otherwise all come from one file.
        mask=True keeps netCDF4's automatic masked-array conversion.

        Multi-dimensional variables have their first (time) index removed and,
        if xyz_rev is set, their axes reversed. This assumes the file stores
        (t, z, y, x) so the result is (x, y, z) — may need to change for
        other data.
        '''
        time0 = self._time_string(time)

        if isinstance(var, str):
            return self._read_3D(self._path_3D(var, time0), [var], mask)[0]

        if self.var_sep:
            # One file per variable: formatting the whole list into a single
            # path (as done previously) can never name an existing file.
            return [self._read_3D(self._path_3D(v, time0), [v], mask)[0]
                    for v in var]

        return self._read_3D(self._path_3D(None, time0), list(var), mask)

    def loadpkl_3D(self, var, time):
        '''
        Loads a pickle file given the variable name and time and returns its
        content as a numpy array.
        '''
        path = (f'{self.dSAM}OUT_3D.{var}/{self.fname}_{var}_'
                f'{self._time_string(time)}.p')
        # NOTE: pickle.load can execute arbitrary code; only use on files
        # produced by trusted tooling.
        with open(path, 'rb') as fl:
            return np.array(load(fl))

    def load_mask(self, time):
        '''
        Loads the mask netCDF for the configured mask id and the given time
        and returns the 'mask' variable with its first (time) index removed.
        '''
        # NOTE(review): unlike the other loaders, the time stamp here is
        # '000000' + str(time) rather than zero-padded to ntimedig digits.
        # Kept as-is in case mask files are written with this naming — confirm.
        path = f'{self.dSAM}MASKS/{self.fname}_MASK{self.mask["id"]}_000000{time}.nc'
        with Dataset(path, 'r') as dat:
            dat.set_auto_mask(False)
            return dat['mask'][0]  # (x,y,z) CHANGE IF UNDERLYING NETCDF IS DIFFERENT SHAPE

    def get_coordinates(self, time):
        '''
        Loads x, y, z, pressure and time from the 3D output at `time` into
        self.coordinates (tuple x, y, z), self.pres and self.time.
        '''
        path = self._path_3D('U', self._time_string(time))
        with Dataset(path, 'r') as dat:
            dat.set_auto_mask(False)
            x = dat['x'][:]
            y = dat['y'][:]
            z = dat['z'][:]
            # The pressure variable name differs between the two layouts.
            pres = dat['pres' if self.var_sep else 'p'][:]
            times = dat['time'][:]

        self.coordinates = (x, y, z)
        self.pres = pres
        self.time = times

    def calc_dxi(self):
        '''
        Sets self.dxi = (dx, dy, z), with dx and dy from the namelist and z
        the height coordinates loaded by get_coordinates(). Note the third
        entry holds height values, not differences.

        Currently works for dx, dy = const and variable vertical spacing.
        '''
        if hasattr(self, 'coordinates'):
            self.dxi = (self.dx, self.dy, self.coordinates[-1])
        else:
            print('No Coordinate Data!')

    def calc_dz(self):
        '''
        Overwrites self.dz with per-level thicknesses derived from the height
        coordinates: dz[0] = 2*z[0] and dz[k] = z[k] - z[k-1] for k >= 1.

        The first level presumably treats z[0] as the midpoint of a layer
        starting at the surface (z = 0) — TODO confirm.
        '''
        if not hasattr(self, 'coordinates'):
            print('No Coordinate Data!')
            return

        z = self.coordinates[-1]
        dz = np.zeros_like(z)
        dz[:1] = 2*z[:1]
        dz[1:] = np.diff(z)
        # NOTE(review): if z marks level midpoints, a thickness recursion such
        # as dz[k] = 2*(z[k] - z[k-1] - dz[k-1]/2) might be expected instead.
        self.dz = dz

    def calc_vorticity(self):
        '''
        Computes self.vorticity from self.U, self.V, self.W and self.dxi using
        the module-level calc_vorticity() function.
        '''
        if all(hasattr(self, name) for name in ('U', 'V', 'W', 'dxi')):
            shape = self.U.shape
            vel = np.array([self.U, self.V, self.W])
            # Resolves to the module-level function, not this method.
            self.vorticity = calc_vorticity(vel, self.dxi, shape)
        else:
            print('Not all three velocity fields present!')

    def calc_shell_count(self, time):
        '''
        Counts the gridpoints of each shell type on every vertical level.

        Loads the mask for `time` and (as a side effect) refreshes the
        coordinate attributes via get_coordinates(). Returns an array of
        shape (number of shell codes, nz), ordered like self.shell_codes,
        assuming the mask is (x, y, z). Prints a warning if the counts do not
        add up to the total number of gridpoints.
        '''
        mask = self.load_mask(time)
        self.get_coordinates(time)

        # Count per shell code over the horizontal axes -> (# shells, z)
        shell_count = [np.count_nonzero(mask == code, axis=(0, 1))
                       for code in self.shell_codes.values()]

        # Every gridpoint should belong to exactly one shell type
        if np.sum(shell_count) != mask.size:
            print('NOT ALL GRIDPOINTS ACCOUNTED FOR!')

        return np.array(shell_count)

        
##### FUNCTIONS #####

def create_threshold_mask(var, threshold):
    '''
    Build a binary mask marking where a field reaches a threshold.

    var - array of values to test (must support `>=` returning an ndarray)
    threshold - numeric value expressed in the same units as var

    Returns an int8 ('byte') array shaped like var holding 1 where
    var >= threshold and 0 elsewhere (NaN compares False, so it maps to 0).
    '''
    meets_threshold = var >= threshold
    return meets_threshold.astype('byte')

# @njit() - Does not work for np.gradient
def calc_vorticity(vel, diff, shape, periodic=True):
    '''
    Compute the 3D vorticity (curl) of SAM's velocity field.

    All three velocity components must already be colocated on the same
    grid. Partial derivatives come from np.gradient (central differences in
    the interior, one-sided at the domain edges). When `periodic` is True,
    the first and last points along x (axis 0) and y (axis 1) are instead
    central-differenced across the wrap-around boundary; z (axis 2) keeps
    np.gradient's one-sided edges.

    vel - length-3 sequence with the U, V and W 3D fields, respectively
    diff - length-3 sequence of spacings (dx, dy, dz). dx and dy should be
           constant scalars; the z entry may instead be the array of
           gridpoint heights, which np.gradient treats as coordinates.
    shape - 3D field shape (e.g. U.shape); only the first three entries
            are used
    periodic - wrap the x/y derivatives around the lateral boundaries

    Returns a float64 ndarray of shape (3, shape[0], shape[1], shape[2])
    holding (dW/dy - dV/dz, dU/dz - dW/dx, dV/dx - dU/dy). Progress messages
    are printed to stdout.
    '''
    print('Calculating Vorticity...')

    def partial(field, axis):
        # float64 buffer so the curl is accumulated in double precision
        out = np.zeros((shape[0], shape[1], shape[2]))
        out[...] = np.gradient(field, diff[axis], axis=axis)
        if periodic and axis in (0, 1):
            src = np.moveaxis(field, axis, 0)
            dst = np.moveaxis(out, axis, 0)  # view: writes land in `out`
            dst[0] = (src[1] - src[-1]) / (2*diff[axis])
            dst[-1] = (src[0] - src[-2]) / (2*diff[axis])
        return out

    # Only off-diagonal derivatives enter the curl; key (i, j) holds
    # d(vel[i]) / d(axis j).
    grads = {}
    for i in range(3):
        field = vel[i]
        grads.update({(i, j): partial(field, j) for j in range(3) if j != i})
        print(f'{i+1} of 3 derivatives done.')
    print('Done.')

    # (added term, subtracted term) for each curl component
    terms = (((2, 1), (1, 2)),   # dW/dy - dV/dz
             ((0, 2), (2, 0)),   # dU/dz - dW/dx
             ((1, 0), (0, 1)))   # dV/dx - dU/dy

    vorticity = np.zeros((3, shape[0], shape[1], shape[2]))
    for i, (plus, minus) in enumerate(terms):
        vorticity[i] = grads[plus] - grads[minus]
        print(f'{i+1} of 3 components done.')

    print('Done.')

    return vorticity

#@njit() # can't use nonregular transpose in njit
def interpolate_3D_velocities(vector, axis=0, periodic=True): # Keep it simple, stupid
    '''
    Given a 3D-array of SAM data, linearly interpolate (average) each point
    with its next neighbour along the specified axis. This colocates vectors
    to scalar positions (and scalars to vector positions).

    For example, when the U-component velocity field is passed with the
    shape (x,y,z), axis=0 interpolates U along the x direction. If V with
    dims (x,y,z) and axis=1, V is interpolated along y, etc. Specifying U and
    axis=1 interpolates U to the edge of the cube (the tau positions where
    i!=j).

    NOTE: Assuming an ARAKAWA-C grid where U[i,j,k] and U[i+1,j,k] average
    to the position of psi[i,j,k] (psi is a scalar field).

    vector - 3D array-like (not modified; a new array is returned)
    axis - int in (0, 1, 2), the axis to interpolate along
    periodic - True: the last point along the axis is averaged with the
               first (wrap-around), suitable for periodic X or Y.
               False: the last point along the axis is set to 0 (top of
               atmosphere).

    Returns an array of the same shape. Floating-point input keeps its
    dtype; integer input yields float64.

    NOTE(review): in the non-periodic case the last level is set to 0 rather
    than averaged with a zero boundary value (0.5*W[-1]); confirm which is
    intended for W at the model top.
    DOES NOT WORK FOR NON-PERIODIC X-Y DIRECTIONS YET, WOULD JUST ADD 0'S
    '''
    vector = np.asarray(vector)

    # np.roll(..., -1) places element i+1 at position i and wraps the first
    # element to the last position -- exactly the periodic neighbour. Working
    # on a new array avoids the previous in-place bug where the already
    # averaged first value was reused for the wrap-around point (and the
    # caller's array was overwritten through a transpose view).
    interpolated = 0.5*(vector + np.roll(vector, -1, axis=axis))

    if not periodic:
        # Top of atmosphere is 0 (moveaxis returns a writable view)
        np.moveaxis(interpolated, axis, 0)[-1] = 0

    return interpolated

# INPROGRESS
# def interpolate_arakawa_corners(field, periodic=True):
#     '''
#     Interpolation function for 3D Arakawa C-grid. Uses
#     scalar values at each cube corner to determine value
#     at cube center.
#     '''
#     datatype = field.dtype
#     dims = field.shape
#     center_values = np.zeros_like((dims[0]+2, dims[1]+2, dims[2])).astype(datatype)
#     for i in range(dims[0]):
#         for j in range(dims[1]):
#             for k in range(dims[2]):
                
#                 centers[i+1,j+1,k] = 


def build_3D_netCDF(path, coords, var_props=None, enum_dict=None, fill=None):
    '''
    Create a NETCDF4 file at `path` with an unlimited time dimension and
    x, y, z dimensions/coordinate variables, optionally defining one
    (time, x, y, z) data variable.

    path - output file path (opened with mode 'w', so it is overwritten)
    coords - (x, y, z) coordinate arrays; their lengths size the dimensions
    var_props - list ['name', 'units', 'long name', datatype] with an
                optional fifth entry, least_significant_digit, passed to
                createVariable for lossy compression
    enum_dict - dictionary for an enum type (i.e. shell codes). When given,
                the variable is created with a uint8 enum type named
                'shell_type' and the datatype entry of var_props is ignored.
    fill - fill value for the data variable; None keeps netCDF4's default.
           (Previously this was handed to the Dataset constructor, where it
           never reached the variable.)

    NOTE: SAM often writes out a time dimension even if there is only one
    timestep in the netCDF file. For continuity purposes, we preserve that
    dimensionality in the netCDF files we create. The time dimension is
    unlimited, in case we wish to use this function for a netCDF file with
    many timesteps. No time values are written here.
    '''
    with Dataset(path, mode='w', format='NETCDF4') as ncf:

        # Dimension Building
        ncf.createDimension('time', None) # Unlimited Dim for assignment
        ncf.createDimension('x', len(coords[0]))
        ncf.createDimension('y', len(coords[1]))
        ncf.createDimension('z', len(coords[2]))

        # Coordinate Building
        time_var = ncf.createVariable('time', np.float32, ('time',))
        time_var.units = 'days'
        time_var.long_name = 'time'
        sub_x_var = ncf.createVariable('x', np.float32, ('x',))
        sub_x_var.units = 'meters (m)'
        sub_x_var.long_name = 'X-Coordinate'
        sub_x_var[:] = coords[0]
        sub_y_var = ncf.createVariable('y', np.float32, ('y',))
        sub_y_var.units = 'meters (m)'
        sub_y_var.long_name = 'Y-Coordinate'
        sub_y_var[:] = coords[1]
        z_var = ncf.createVariable('z', np.float32, ('z',))
        z_var.units = 'meters (m)'
        z_var.long_name = 'Height'
        z_var[:] = coords[2]

        # Variable Creation (Can be done externally)
        if var_props is not None:
            dims = ('time', 'x', 'y', 'z')

            if enum_dict is None:
                # Float (or other plain) variable with zlib compression;
                # optional least_significant_digit as the fifth entry.
                extra = {}
                if len(var_props) > 4:
                    extra['least_significant_digit'] = var_props[4]
                var = ncf.createVariable(var_props[0], var_props[3], dims,
                                         zlib=True, fill_value=fill, **extra)
            else:
                # I don't really care about compression for uint8 data type
                shell_type = ncf.createEnumType(np.uint8, 'shell_type', enum_dict)
                var = ncf.createVariable(var_props[0], shell_type, dims,
                                         fill_value=fill)

            var.units = var_props[1]
            var.long_name = var_props[2]


def build_ragged_netCDF(path, coords, fill=None):
    '''
    Work-in-progress builder for a ragged-array netCDF file.

    Currently it only creates (or overwrites) an empty NETCDF4 file at
    `path` and closes it again. `coords` and `fill` are accepted for the
    eventual implementation but are not used yet.

    Previously this function referenced an undefined name `fill` (always
    raising NameError) and never closed the dataset; `fill` is now an
    optional parameter.

    path - output file path
    coords - coordinate arrays (unused for now)
    fill - fill value for future data variables (unused for now)
    '''
    # TODO: define dimensions/variables for the ragged layout.
    ncf = Dataset(path, mode='w', format='NETCDF4')
    ncf.close()



#### PACKAGE RELIANT INTERPOLATION #### - VERY ANNOYING

# # Create the actual vector coordinates in x,y,z
# # Only one corrected coordinate is needed for each
# # vector field
# vector_coordinates = []
# for dim, coord in enumerate(sam.coordinates):

#     # Strings will be a non-constant coordinate (z)
#     # Shift forward to staggered location
#     if isinstance(sam.dxi[dim], str):
#         midpoints = 0.5*(coord[1:] + coord[:-1]) # Rectilinear uses midpoint of levels
#         midpoints = np.append(midpoints, 2*coord[-1] - midpoints[-1]) # Interp final point
#         staggered_coord  = midpoints
#     else:
#         staggered_coord = coord + 0.5*sam.dxi[dim]

#     vector_coordinates.append(staggered_coord)

# sam.U = sam.U.assign_coords(z=vector_coordinates[2])
# #sam.V.assign_coords(y=vector_coordinates[1])
# #sam.W.assign_coords(z=vector_coordinates[2])

# #plt.plot(sam.U.U.data.mean(axis=(0,1)),sam.U.z.data)

# print('Interpolating...')
# sam.U = sam.U.interp(x=sam.coordinates[0], method='linear')
# #sam.V.interp(y=sam.coordinates[1])
# #sam.W.interp(z=sam.coordinates[2])
# sam.U.U.shape

# #plt.plot