Source code for fiesta.disperse

######################################################################
# FIlamEntary STructure Analysis (FIESTA) -- Pre-release
# Authors: Jiten Dhandha, Zoe Faes and Rowan Smith
# Last edited: September 2022
######################################################################

"""
This module contains classes for dealing with filamentary networks 
from DisPerSE, and a variety of very useful algorithms such as
readjusting filament spines, characterizing filament lengths and masses,
identifying junctions/hubs in the network and more!
"""

######################################################################
#                            LIBRARIES                               #
######################################################################

#Standard libs
import copy
#Numpy
import numpy as np
#Scipy
from scipy import stats
from scipy.interpolate import splprep, splev
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist
from scipy.spatial.transform import Rotation
#Astropy
from astropy import units as u
#Matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#FIESTA
from fiesta import units as ufi
from fiesta import _utils as utils

######################################################################
#                        INTERNAL FUNCTIONS                          #
######################################################################

def _magnitude(v):
    return np.sqrt(v[0]**2+v[1]**2+v[2]**2)

def _normalize(v):
    return v/_magnitude(v)

#Returns the rotation matrix for the angle between the given vector and [0,0,1]
def _get_rotation_matrix(v):
    n = _normalize(v)
    theta = np.arccos(n[2])
    u = _normalize(np.cross(n,[0,0,1]))
    rot = Rotation.from_rotvec(theta * u)
    return rot

def _get_spheres(points, grid, SPHERE_RADIUS, grid_tree=None):

        #Initialize arrays
        all_sphere_ids = []
        all_sphere_points = []

        #STEP 1: Creating the KDTree from grid (if not pre-evaulated and passed)
        if grid_tree is None:
            grid_tree = KDTree(grid)

        #STEP 2: Querying all cells within a sphere of radius SPHERE_RADIUS
        all_sphere_ids = grid_tree.query_ball_point(points, r=SPHERE_RADIUS)
        for sphere_ids in all_sphere_ids:
            all_sphere_points.append(grid[sphere_ids])

        return all_sphere_ids, all_sphere_points

def _get_cylinders(points, grid, CYL_RADIUS):

        ids = np.arange(len(grid),dtype=int)

        #Initialize arrays
        vectors = []
        rotation_matrices =[]
        all_cylinder_ids = []
        all_cylinder_points = []
        
        #STEP 1: Find vectors between successive points
        vectors = np.diff(points,axis=0)
        
        for i in range(len(points)-1):

            #STEP 2: Extract a cubical region of side 4*CYL_RADIUS around point
            SEARCH_RADIUS = 2*CYL_RADIUS
            REGION_MASK = ((grid[:,0] > points[i,0]-SEARCH_RADIUS) &
                           (grid[:,0] < points[i,0]+SEARCH_RADIUS) &
                           (grid[:,1] > points[i,1]-SEARCH_RADIUS) &
                           (grid[:,1] < points[i,1]+SEARCH_RADIUS) &
                           (grid[:,2] > points[i,2]-SEARCH_RADIUS) &
                           (grid[:,2] < points[i,2]+SEARCH_RADIUS))
            region_points = grid[REGION_MASK]
            region_ids = ids[REGION_MASK]

            #STEP 3: Translate and rotate the region and filament to the origin
            rot = _get_rotation_matrix(vectors[i])
            points_at_origin = rot.apply(points-points[i])
            region_at_origin = rot.apply(region_points-points[i])
            rotation_matrices.append(rot)
        
            #STEP 4: Querying all cells within a cylinder of radius CYL_RADIUS
            CYLINDER_MASK = ((region_at_origin[:,0]**2 + region_at_origin[:,1]**2 <= CYL_RADIUS**2) &
                             (region_at_origin[:,2] > min(points_at_origin[i,2],points_at_origin[i+1,2])) &
                             (region_at_origin[:,2] < max(points_at_origin[i,2],points_at_origin[i+1,2])))
            cylinder_ids = region_ids[CYLINDER_MASK]
            cylinder_points = region_points[CYLINDER_MASK]
            
            #Adding to array
            all_cylinder_ids.append(cylinder_ids)
            all_cylinder_points.append(cylinder_points)

        return vectors, rotation_matrices, all_cylinder_ids, all_cylinder_points

######################################################################
#                         CRITICAL POINT                             #
######################################################################

[docs]class CriticalPoint: """ Class to deal with critical points output by DisPerSE. Attributes ---------- idx : `int` Unique index for the critical point in the `~fiesta.disperse.Network`. pos : `~astropy.units.Quantity` Position :math:`(x,y,z)` of the critical point. Defaults to units of ``u.pix``. nfil : `int` Number of filaments connected to the critical point. scale : `str` Either ``pixel`` if the grid is not scaled to the source AREPO snapshot, else ``physical``. """ def __init__(self): self.idx = None self.pos = None self.nfil = None self.scale = None
###################################################################### # FILAMENT # ######################################################################
[docs]class Filament: """ Class to deal with filaments output by DisPerSE. Attributes ---------- idx : `int` Unique index for the filament in the `~fiesta.disperse.Network`. cps : pair of `~fiesta.disperse.CriticalPoint` Critical point at the start and end of the filament. nsamp : `int` Number of sampling points of the filament. samps : `~astropy.units.Quantity` Array of position :math:`(x,y,z)` of each sampling point of the filament. Defaults to units of ``u.pix``. scale : `str` Either ``pixel`` if the grid is not scaled to the source AREPO snapshot, else ``physical``. length : `~astropy.units.Quantity` Length of the filament. Set only after calling `~fiesta.disperse.Filament.calc_length`. radius : `~astropy.units.Quantity` Radius of the filament. Fixed value set during `~fiesta.disperse.Filament.characterize_filament`. mass : `~astropy.units.Quantity` Mass of the filament. Set only after calling `~fiesta.disperse.Filament.calc_mass`. arepo_ids : `list` of `int`'s List of indices of `~fiesta.arepo.ArepoVoronoiGrid` cells that are "part" of the filament. Set only after calling `~fiesta.disperse.Filament.characterize_filament`. """ def __init__(self): #Core variables self.idx = None self.cps = None self.nsamp = None self.samps = None self.scale = None #NOTE: The samps contain the critical points already (as first and last element) #Tools self._filfunc = None self.arepo_ids = None #Properties self.length = None self.radius = None self.mass = None ###################################################################### # filfunc # ###################################################################### #The filfunc is a function that parametrizes the 1D curve in 3D space: #f(s) = (x,y,z) where s = [0,1] -- the normalized distance along the filament. def _set_filfunc(self): #Since you can always fit a N-1 degree polynomial to N points if(self.nsamp>3): degree = 3 else: degree = self.nsamp-1 #Note the smoothing s=0 so that the fit curve passes through all sampling points if(self.scale=="physical"): internal_unit = u.cm elif(self.scale=="pixel"): internal_unit = u.pix points = self.samps.to_value(internal_unit) func, _ = splprep(points.T,s=0,k=degree) self._filfunc = func
[docs] def get_points(self, s, unit=None): """ Function to sample any point on the filament spine. A *filament function* is one that parameterizes the 1D filament curve in 3D space, defined as :math:`\mathbf{f}(s) = (x,y,z)` where :math:`s \in [0,1]` is the normalized distance of a point along the filament spine and :math:`(x,y,z)` is the position vector of that point in space. Since DisPerSE only outputs a discrete set of sampling points, fitting :math:`\mathbf{f}(s)` allows one to sample the filament continuously or as finely as required. In this implementation, the function is a cubic B-spline fit using `scipy.interpolate.splprep`, with zero-smoothing to ensure it passes through input points. Attributes ---------- s : `list` or `~numpy.ndarray` of `float`'s Array of normalized distances along the filament :math:`s \in [0,1]`. unit : `~astropy.units.Unit` Unit to output the points in. If ``None`` (default), takes the value ``u.pix`` if ``scale='pixel'``, else ``u.cm`` if ``scale='physical'``. Returns ---------- points : `~astropy.units.Quantity` Array of 3d position vectors corresponding to the points at the given normalized distances :math:`s` along the filament. """ if(self.scale=="physical"): internal_unit = u.cm elif(self.scale=="pixel"): internal_unit = u.pix if unit is None: unit = internal_unit utils.check_unit(unit, internal_unit) points = (splev(s,self._filfunc) << internal_unit).T return points.to(unit)
###################################################################### # Getters # ######################################################################
[docs] def get_length(self, unit=None): """ Returns `~fiesta.disperse.Filament.length`. Parameters ---------- unit : `~astropy.units.Unit`, optional Unit to output the result in. If ``None`` (default), takes the value ``u.pix`` if ``scale='pixel'``, else ``u.cm`` if ``scale='physical'``. """ if self.length is None: raise ValueError(utils._prestring() + "Length needs to be calculated first using calc_length() function.") else: if(self.scale=="physical"): internal_unit = u.cm elif(self.scale=="pixel"): internal_unit = u.pix if unit is None: unit = internal_unit utils.check_unit(unit, internal_unit) return self.length.to(unit)
[docs] def get_arepo_ids(self): """ Returns `~fiesta.disperse.Filament.arepo_ids`. """ if self.arepo_ids is None: raise ValueError(utils._prestring() + "Filament needs to be characterized first using characterize_filament() function.") else: return self.arepo_ids
[docs] def get_radius(self, unit=u.cm): """ Returns `~fiesta.disperse.Filament.radius`. Parameters ---------- unit : `~astropy.units.Unit`, optional Unit to output the result in. Default value is ``u.cm``. """ if self.radius is None: raise ValueError(utils._prestring() + "Filament needs to be characterized first using characterize_filament() function.") else: utils.check_unit(unit, u.cm) return self.radius.to(unit)
[docs] def get_mass(self, unit=u.g): """ Returns `~fiesta.disperse.Filament.mass`. Parameters ---------- unit : `~astropy.units.Unit`, optional Unit to output the result in. Default value is ``u.g``. """ if self.mass is None: raise ValueError(utils._prestring() + "Mass needs to be calculated first using calc_mass() function.") else: utils.check_unit(unit, u.g) return self.mass.to(unit)
###################################################################### # ANALYZING THE FILAMENT # ######################################################################
[docs] def correct_spine(self, avg, sphere_radius, method='center of density', verbose=False, plot_correction=False, length_unit=u.cm): """ Function to correct the spine of a filament, based on source `~fiesta.arepo.ArepoVoronoiGrid`. Since DisPerSE works on a regular cubic grid (i.e., `~fiesta.arepo.ArepoCubicGrid`) which has a finite resolution that washes over local overdensities, it can be useful to correct the spine of the filament, once identified on the coarser regular grid, using the original finer Voronoi grid. The correction is made by querying all AREPO cells within a given ``sphere_radius`` of each sampling point, and then shifting the sampling point based on the ``method`` of correction. Note that the filament must have ``scale='physical'`` to use this functionality. The function updates the `~fiesta.disperse.Filament.samps` and the *filament function* internally! Attributes ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for correcting the spine. sphere_radius : `~astropy.units.Quantity` Spherical radius of correction for each sampling point. method : `str`, optional Either ``'max density'`` or ``'center of density'`` (default). The former shifts the sampling point to the densest cell within the given spherical radius, while the latter calculates and shifts it to the center of density in the sphere. verbose : `bool`, optional If ``True``, prints output during spine correction. Default value is ``False``. plot_correction : `bool`, optional If ``True``, plots filament before and after the spine correction. Useful for checks! Default value is ``False``. length_unit : `~astropy.units.Unit`, optional The unit of length to use for the plot. Default value is ``u.cm``. """ if(verbose): print(utils._prestring() + "Correcting Filament spine for idx = {}".format(self.idx)) if(self.scale!="physical"): raise ValueError(utils._prestring() + "Filament must be in \"physical\" units.") #AREPO data variables AREPO_mass = avg.mass[avg.gas_ids].to_value(u.g) AREPO_ndensity = avg.get_ndensity().to_value(u.cm**-3) AREPO_pos = avg.pos[avg.gas_ids].to_value(u.cm) if avg._tree is None: avg._tree = KDTree(AREPO_pos) AREPO_tree = avg._tree #Getting points old_points = self.samps old_points_val = self.samps.to_value(u.cm) old_spline = self.get_points(np.linspace(0,1,self.nsamp*100,endpoint=True)) utils.check_quantity(sphere_radius, u.cm, "sphere_radius") sphere_radius = sphere_radius.to_value(u.cm) #Initialize arrays all_sphere_ids, all_sphere_points = _get_spheres(old_points_val, AREPO_pos, sphere_radius, AREPO_tree) new_filament = [] for i in range(len(old_points_val)): if all_sphere_points[i].size==0: new_filament.append(old_points_val[i]) continue else: if(method.lower()=='max density'): corr_idx = np.argmax(AREPO_ndensity[all_sphere_ids[i]]) corr_point = all_sphere_points[i][corr_idx] elif(method.lower()=='center of density'): corr_point = np.average(all_sphere_points[i], weights=AREPO_ndensity[all_sphere_ids[i]], axis=0) else: raise ValueError(utils._prestring() + "Invalid method of spine correction.") corr_point = list(corr_point) if corr_point not in new_filament: new_filament.append(corr_point) new_filament = u.Quantity(new_filament,unit=u.cm).to(ufi.AREPO_LENGTH) #Saving the corrected filament self.samps = new_filament self.nsamp = len(new_filament) self._set_filfunc() new_spline = self.get_points(np.linspace(0,1,self.nsamp*100,endpoint=True)) #Plot the correction if plot_correction: fig = plt.figure(figsize=(8,8)) ax = fig.add_subplot(111,projection='3d') utils.check_unit(length_unit, u.cm) ax.set_xlabel(r"$x$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_ylabel(r"$y$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_zlabel(r"$z$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.scatter(*old_points.to_value(length_unit).T, color='black', s=5) ax.plot(*old_spline.to_value(length_unit).T, color='grey', linewidth=2) ax.scatter(*self.samps.to_value(length_unit).T, color='red', s=5) ax.plot(*new_spline.to_value(length_unit).T, color='gold', linewidth=2) for sphere_points in all_sphere_points: sphere_points = sphere_points << u.cm if(not sphere_points.size == 0): ax.scatter(*sphere_points.to_value(length_unit).T, alpha=0.02)
[docs] def characterize_filament(self, avg, cylinder_radius, npoints=None, verbose=False, plot_filament=False, length_unit=u.cm): """ Function to query all simulation cells around the filament, based on source `~fiesta.arepo.ArepoVoronoiGrid`. The function splits the filament into equidistant ``npoints``, joining them by cylinders of radius ``cylinder_radius``. All AREPO cells within these cylinders are considered "part" of the filament. The indices of the cells are then stored as the instance variable `~fiesta.disperse.Filament.arepo_ids`, and the ``cylinder_radius`` sets the instance variable `~fiesta.disperse.Filament.radius`. Note that double-counting is **not** corrected for in the indices. In order to do this, try `np.unique(np.hstack(arepo_ids))`. Attributes ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for characterizing the filament. cylinder_radius : `float` or `~astropy.units.Quantity` Radius of cylinder to query AREPO cells within. If not a `~astropy.units.Quantity`, assumed to be in ``u.cm`` units. npoints : `int`, optional Number of points to split the filament into. If ``None`` (default), then the sampling points `~fiesta.disperse.Filament.samps` are used. verbose : `bool`, optional If ``True``, prints output during filament characterization. Default value is ``False``. plot_filament : `bool`, optional If ``True``, plots the filament along with all its AREPO cells Useful for checks! Default value is ``False``. length_unit : `~astropy.units.Unit`, optional The unit of length to use for the plot. Default value is ``u.cm``. """ if(verbose): print(utils._prestring() + "Characterizing Filament for idx = {}".format(self.idx)) if(self.scale!="physical"): raise ValueError(utils._prestring() + "Filament must be in \"physical\" units.") #AREPO data variables AREPO_pos = avg.pos[avg.gas_ids].to_value(u.cm) #Getting points if(npoints is None): npoints = self.nsamp points = self.samps else: points = self.get_points(np.linspace(0,1,npoints,endpoint=True)) points_val = points.to_value(u.cm) utils.check_quantity(cylinder_radius, u.cm, "cylinder_radius") self.radius = cylinder_radius.to(u.cm) cylinder_radius = cylinder_radius.to_value(u.cm) #Initialize arrays vectors, rotation_matrices, all_cylinder_ids, all_cylinder_points = _get_cylinders(points_val, AREPO_pos, cylinder_radius) self.arepo_ids = all_cylinder_ids #Plot the filament if plot_filament: fig = plt.figure(figsize=(8,8)) ax = fig.add_subplot(111,projection='3d') utils.check_unit(length_unit, u.cm) ax.set_xlabel(r"$x$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_ylabel(r"$y$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_zlabel(r"$z$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.scatter(*points.to_value(length_unit).T, color='black', s=5) spline = self.get_points(np.linspace(0,1,npoints,endpoint=True)) ax.plot(*spline.to_value(length_unit).T, color='grey', linewidth=2) for cylinder_points in all_cylinder_points: cylinder_points = cylinder_points << u.cm if(not cylinder_points.size == 0): ax.scatter(*cylinder_points.to_value(length_unit).T, alpha=0.02)
###################################################################### # PHYSICAL PROPERTIES # ######################################################################
[docs] def calc_length(self, verbose=False): """ Calculates the length of the filament, either in pixels or physical units depending on `~fiesta.disperse.Filament.scale`, by finely integrating the Euclidean distance between very close points on the filament. Attributes ---------- verbose : `bool`, optional If ``True``, prints output during length calculation. Default value is ``False``. """ if(verbose): print(utils._prestring() + "Calculating Filament length for idx = {}".format(self.idx)) #Summing Euclidean distance between very close points spline = self.get_points(np.linspace(0,1,self.nsamp*500,endpoint=True)) self.length = np.sqrt((np.diff(spline,axis=0) ** 2).sum(axis=1)).sum() << spline.unit if(verbose): print(utils._prestring() + "The length of the filament is {}".format(self.length))
[docs] def calc_mass(self, avg, verbose=False): """ Calculates the mass of the filament by summing the mass of all cells given by `~fiesta.disperse.Filament.arepo_ids` (corrected for double-counting). Note that this requires calling `~fiesta.disperse.Filament.characterize_filament` first to set `~fiesta.disperse.Filament.arepo_ids`. Attributes ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for calculating mass of the filament. verbose : `bool`, optional If ``True``, prints output during mass calculation. Default value is ``False``. """ if verbose: print(utils._prestring() + "Calculating Filament mass for idx = {}".format(self.idx)) if self.arepo_ids is None: raise ValueError(utils._prestring() + "Filament needs to be characterized first using characterize_filament() function.") unique_ids = np.unique(np.hstack(self.arepo_ids)) self.mass = np.sum(avg.mass[unique_ids]) if(verbose): print(utils._prestring() + "The mass of the filament is {}".format(self.mass))
[docs] def plot_density_profile(self, avg, length_unit=u.cm, ndensity_unit=u.cm**-3, scatter=True, contours=None, save=None, **kwargs): """ Plot mass-weighted 2d histogram projection of the AREPO Voronoi grid along an axis. Parameters ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for plotting the profile length_unit : `~astropy.units.Unit`, optional The unit of length to use. If ``None`` (default), takes the value ``u.pix`` if ``scale='pixel'``, else ``u.cm`` if ``scale='physical'``. ndensity_unit : `~astropy.units.Unit`, optional The unit of number density to use. Default value is ``u.cm**-3``. scatter: `bool`, optional If ``True`` (default), plots the scatter of radius versus density for each cell. Its color can be controlled through ``color`` argument in ``**kwargs``. contours: `int`, optional Number of contours, estimated through `scipy.stats.gaussian_kde`, of the scatter. If ``None`` (default), no contour plot is made. Its color can be controlled through ``cmap`` argument in ``**kwargs``. save : str, optional The name of the file to save the plot as. If ``None`` (default), plot is not saved. **kwargs : dict, optional Additional *matplotlib*-based keyword arguments to control finer details of the plot. Returns ------- fig : `~matplotlib.figure.Figure` Main `~matplotlib.figure.Figure` instance. """ if self.arepo_ids is None: raise ValueError(utils._prestring() + "Filament needs to be characterized first using characterize_filament() function.") #Shedding units for ease of use points = avg.pos.to_value(length_unit) spline = self.get_points(np.linspace(0,1,self.nsamp*500)).to_value(length_unit) ndensity = avg.get_ndensity().to_value(ndensity_unit) mass = avg.mass.to_value(u.g) #Main figure fig = plt.figure(figsize=(8,8)) if "figure" in kwargs: plt.setp(fig,**kwargs["figure"]) ax = fig.add_subplot(111) #Axes scales ax.set_yscale('log') if "xscale" in kwargs: ax.set_xscale(**kwargs["xscale"]) if "yscale" in kwargs: ax.set_yscale(**kwargs["yscale"]) #Axes ticks ax.xaxis.set_tick_params(which='major', width=1, length=5, labelsize=15) ax.xaxis.set_tick_params(which='minor', width=1, length=2.5, labelsize=10) ax.yaxis.set_tick_params(which='major', width=1, length=5, labelsize=15) ax.yaxis.set_tick_params(which='minor', width=1, length=2.5, labelsize=10) if "xtick_params" in kwargs: ax.xaxis.set_tick_params(**kwargs["xtick_params"]) if "ytick_params" in kwargs: ax.yaxis.set_tick_params(**kwargs["ytick_params"]) if "tick_params" in kwargs: ax.tick_params(**kwargs["tick_params"]) #Axes labels ax.set_xlabel(r"Radial distance $r$ [{}]".format(length_unit.to_string()),fontsize=15) ax.set_ylabel(r"Number density $n$ [{}]".format(ndensity_unit.to_string()),fontsize=15) if "xlabel" in kwargs: ax.set_xlabel(**kwargs["xlabel"]) if "ylabel" in kwargs: ax.set_ylabel(**kwargs["ylabel"]) #Figure title ax.set_title("",fontsize=15) if "title" in kwargs: ax.set_title(**kwargs["title"]) ############### Plotting start ################ #Colors! if "color" in kwargs: color1 = kwargs["color"] else: color1 = 'black' if "cmap" in kwargs: cmap2 = copy.copy(mpl.cm.get_cmap(kwargs["cmap"])) else: cmap2 = copy.copy(mpl.cm.get_cmap('Blues')) alpha1 = 0.8 alpha2 = 0.8 def density_estimation(m1, m2, weights=None): #see https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html X, Y = np.mgrid[min(m1):max(m1):100j, min(m2):max(m2):100j] positions = np.vstack([X.ravel(), Y.ravel()]) values = np.vstack([m1, m2]) kernel = stats.gaussian_kde(values,weights=weights) Z = np.reshape(kernel(positions).T, X.shape) return X, Y, Z #Getting unique AREPO ids unique_ids = np.unique(np.hstack(self.arepo_ids)) points = points[unique_ids] ndensity = ndensity[unique_ids] mass = mass[unique_ids] #Finding the radial distance of points and contour of scatter radius = cdist(points,spline).min(axis=1) X, Y, Z = density_estimation(radius,np.log10(ndensity)) #Note: the logarithm #Plotting now if scatter: scatter = ax.scatter(radius, ndensity, s=0.1, c=color1, alpha=alpha1, zorder=1) if contours is not None: contourf = ax.contourf(X, np.power(10,Y), Z, cmap=cmap2, alpha=alpha2, levels=contours, zorder=0) ############### Plotting end ################ #Axes limits if "xlim" in kwargs: ax.set_xlim(**kwargs["xlim"]) if "ylim" in kwargs: ax.set_ylim(**kwargs["ylim"]) #Text if "text" in kwargs: ax.text(**kwargs["text"],transform=ax.transAxes) if save is not None: fig.savefig(save, bbox_inches='tight', dpi=100) return fig
###################################################################### # NETWORK # ######################################################################
[docs]class Network: """ Class to deal with filament network output by DisPerSE. Attributes ---------- cps : `list` of `~fiesta.disperse.CriticalPoint`'s Array of `~fiesta.disperse.CriticalPoint`'s in the network. ncps : int Number of critical points in the network. fils : `list` of `~fiesta.disperse.Filament`'s Array of `~fiesta.disperse.Filament`'s in the network. nfils : `int` Number of filaments in the network scale : `str` Either ``pixel`` if the grid is not scaled to the source AREPO snapshot, else ``physical``. Parameters ---------- file_path : `str` File path of the DisPerSE output (currently only supports ASCII format). verbose : `bool`, optional If ``True``, prints output during file reading. Default value is ``False``. """ ###################################################################### # CONSTRUCTORS # ###################################################################### def __init__(self, file_path, verbose=False): #Assumes ASCII format (see http://www2.iap.fr/users/sousbie/web/html/indexbea5.html?post/NDskl_ascii-format) if(verbose): print(utils._prestring() + "Loading Network from DisPerSE ASCII file from \"{}\" ...".format(file_path)) #FIRST STORING THE DIFFERENT SECTIONS OF THE ASCII OUTPUT with open(file_path) as f: text = f.read() stop1 = text.rfind("[CRITICAL POINTS]") stop2 = text.rfind("[FILAMENTS]") stop3 = text.rfind("[CRITICAL POINTS DATA]") stop4 = text.rfind("[FILAMENTS DATA]") critical_points_text = text[stop1:stop2] filaments_text = text[stop2:stop3] critical_data_text = text[stop3:stop4] filaments_data_text = text[stop4:] #READING CRITICAL POINT DATA critical_points = [] critical_points_text_array = critical_points_text.splitlines()[2:] for i in range(len(critical_points_text_array)): line = critical_points_text_array[i] if(not line[0]==' '): l = line.split(' ') #Creating the object cp = CriticalPoint() cp.idx = len(critical_points) #Setting 'idx' here cp.pos = [float(l[3]),float(l[2]),float(l[1])] << u.pix #NOTE: Notice the order of l[1], l[2] and l[3]. This is the discrepancy in Disperse coordinates and AREPO. cp.nfil = int(critical_points_text_array[i+1]) cp.scale = "pixel" #Adding to array critical_points.append(cp) critical_points = np.array(critical_points) #READING FILAMENT DATA filaments = [] filaments_text_array = filaments_text.splitlines()[2:] for i in range(len(filaments_text_array)): line = filaments_text_array[i] if(not line[0]==' '): l = line.split(' ') #Creating the object fil = Filament() fil.idx = len(filaments) fil.nsamp = int(l[2]) fil.cps = np.array([critical_points[int(l[0])],critical_points[int(l[1])]]) fil.scale = "pixel" fil.samps = [] while(True): if(i==len(filaments_text_array)-1): break else: i+=1 line = filaments_text_array[i] if(not line[0]==' '): i-=1 fil.samps = fil.samps << u.pix fil._set_filfunc() #Adding to array filaments.append(fil) break else: l = line.split(' ') pos = [float(l[3]),float(l[2]),float(l[1])] #NOTE: Notice the order of l[1], l[2] and l[3]. This is the discrepancy in Disperse coordinates and AREPO. fil.samps.append(pos) filaments = np.array(filaments) if(verbose): print(utils._prestring() + "Completed loading Network from DisPerSE ASCII file from \"{}\" .".format(file_path)) self.cps = critical_points self.ncps = len(critical_points) self.fils = filaments self.nfils = len(filaments) self.scale = "pixel" if(verbose): print(utils._prestring() + "Note: The Network has {} filaments and {} critical points.".format(self.nfils,self.ncps)) ###################################################################### # ACCESS FUNCTIONS # ######################################################################
[docs] def get_filament(self, idx): """ Function to fetch a filament of specific ID. Parameters ---------- idx : `int` ID of the filament to fetch. Returns ---------- fil : `~fiesta.disperse.Filament` Filament corresponding to the given ID. """ for fil in self.fils: if(fil.idx == idx): return fil print(utils._prestring() + "No filament with idx = {}.".format(idx))
[docs] def get_cp(self, idx): """ Function to fetch a critical point of specific ID. Parameters ---------- idx : `int` ID of the critical point to fetch. Returns ---------- fil : `~fiesta.disperse.CriticalPoint` Critical point corresponding to the given ID. """ for cp in self.cps: if(cp.idx == idx): return cp print(utils._prestring() + "No critical point with idx = {}.".format(idx))
###################################################################### # RESCALE TO ORIGINAL AREPO # ######################################################################
[docs] def set_scale(self, acg): """ Function to set the physical scale of the network by passing the `~fiesta.arepo.ArepoCubicGrid` that was used as input for DisPerSE. Parameters ---------- acg : `~fiesta.arepo.ArepoCubicGrid` Source `~fiesta.arepo.ArepoCubicGrid` to use for setting the scale. """ if(self.scale=="physical"): print(utils._prestring()+"Already in \"physical\" units.") return if(acg.scale=="physical"): self.scale = acg.scale #Scale critical points for cp in self.cps: cp.scale = acg.scale cp.pos = u.Quantity([acg.xmin + cp.pos[0].value * acg.xlength, acg.ymin + cp.pos[1].value * acg.ylength, acg.zmin + cp.pos[2].value * acg.zlength]) #Scale filaments for fil in self.fils: fil.scale = acg.scale #No need to change the cps here since they are already changed above fil.samps = u.Quantity([acg.xmin + fil.samps[:,0].value * acg.xlength, acg.ymin + fil.samps[:,1].value * acg.ylength, acg.zmin + fil.samps[:,2].value * acg.zlength]).T #Resetting the fil function fil._set_filfunc() else: raise ValueError(utils._prestring() + "ArepoCubicGrid must be in \"physical\" units.")
###################################################################### # PROPERTIES OF THE NETWORK # ######################################################################
[docs] def calc_filament_lengths(self, verbose=False): """ Calculate the length of all filaments in the network. Works simply by calling `~fiesta.disperse.Filament.calc_length` for each filament. Attributes ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for calculating mass of the filament. verbose : `bool`, optional If ``True``, prints output during length calculation. Default value is ``False``. """ if(verbose): print(utils._prestring() + "Completed calculating length of all filaments in Network...") for fil in self.fils: fil.calc_length(verbose) if(verbose): print(utils._prestring() + "Started calculating length of all filaments in Network...")
[docs] def calc_filament_masses(self, avg, verbose=False): """ Calculate the mass of all filaments in the network. Works simply by calling `~fiesta.disperse.Filament.calc_mass` for each filament. Attributes ---------- avg : `~fiesta.arepo.ArepoVoronoiGrid` Source `~fiesta.arepo.ArepoVoronoiGrid` to use for calculating mass of filaments. verbose : `bool`, optional If ``True``, prints output during mass calculation. Default value is ``False``. """ if(verbose): print(utils._prestring() + "Completed calculating mass of all filaments in Network...") for fil in self.fils: fil.calc_mass(avg, verbose) if(verbose): print(utils._prestring() + "Completed calculating mass of all filaments in Network...")
###################################################################### # FILTERING FILAMENTS # ######################################################################
[docs] def remove_short_filaments(self, min_length, verbose=False): """ Remove filaments below a given minimum length. Useful for discarding spurious filaments below grid resolution. Attributes ---------- min_length : `~astropy.units.Quantity` Length below which all filaments are discarded from the network. verbose : `bool`, optional If ``True``, prints output during removal of short filaments. Default value is ``False``. """ #Shedding units for ease of use if(self.scale=="physical"): unit = u.cm elif(self.scale=="pixel"): unit = u.pix utils.check_quantity(min_length, unit, "min_length") min_length = min_length.to_value(unit) #First filtering filaments fils_to_keep = [] for i in range(self.nfils): fil = self.fils[i] if fil.length is None : raise ValueError(utils._prestring() + "Filament length is None. Need to calculate filament lengths first") if(fil.length.to_value(unit) > min_length): fils_to_keep.append(i) if(verbose): nfils_removed = self.nfils-len(fils_to_keep) print(utils._prestring() + "Removed {} short filaments from Network.".format(nfils_removed)) self.fils = self.fils[fils_to_keep] self.nfils = len(self.fils) #Now updating the critical points corresponding to these cp_idxs_in_fils = [cp.idx for fil in self.fils for cp in fil.cps] cps_to_keep = [] for j in range(self.ncps): cp = self.cps[j] if(cp.idx in cp_idxs_in_fils): cps_to_keep.append(j) cp.nfil = cp_idxs_in_fils.count(cp.idx) if(verbose): ncps_removed = self.ncps-len(cps_to_keep) print(utils._prestring() + "Removed {} critical points from Network.".format(ncps_removed)) self.cps = self.cps[cps_to_keep] self.ncps = len(self.cps)
###################################################################### # PLOTTING FUNCTIONS # ######################################################################
[docs] def plot_network(self, fil_idxs=None, length_unit=None, colors=None, splines=False, cylinders=False, avg=None, save=None, **kwargs): """ Plot the filament network in 3D. Parameters ---------- fil_idxs : `list` of int`'s ID's of the filaments to plot. If ``None`` (default), all filaments are plotted. length_unit : `~astropy.units.Unit`, optional The unit of length to use for the plot. If ``None`` (default), takes the value ``u.pix`` if ``scale='pixel'``, else ``u.cm`` if ``scale='physical'``. colors: `list`, optional A list of *matplotlib*-compatible color strings corresponding to each filament plotted. If ``None`` (default), colors are picked by cycling through *matplotlib* tab10 colour palette. splines : `bool`, optional If ``True``, plots the spline fit (*filament function*) of each filament. Default value is ``False``. cylinders : `bool`, optional If ``True``, plots the AREPO cells that are part of the filament. Default value is ``False``. Requires `~fiesta.disperse.Filament.arepo_ids` to be set. avg : `~fiesta.arepo.ArepoVoronoiGrid`, optional The source `~fiesta.arepo.ArepoVoronoiGrid` snapshot to use for plotting the cells. Omitted if ``cylinders=False``. save : `str`, optional File path to save the plot. If ``None`` (default), plot is not saved. **kwargs : `dict`, optional Additional *matplotlib*-based keyword arguments to control finer details of the plot. Returns ------- fig : `~matplotlib.figure.Figure` Main `~matplotlib.figure.Figure` instance. """ #Checking units if(self.scale=="physical"): internal_unit = u.cm elif(self.scale=="pixel"): internal_unit = u.pix if length_unit is None: length_unit = internal_unit utils.check_unit(length_unit, internal_unit) if fil_idxs is None: fil_idxs = [fil.idx for fil in self.fils] #Figure properties nsols = len(fil_idxs) if colors is None: cmap = plt.cm.tab10 colors = cmap(np.arange(nsols)%cmap.N) #Main figure fig = plt.figure(figsize=(8,8)) if "figure" in kwargs: plt.setp(fig,**kwargs["figure"]) ax = fig.add_subplot(111,projection='3d') #Axes ticks ax.xaxis.set_tick_params(which='major', width=1, length=5, labelsize=15) ax.xaxis.set_tick_params(which='minor', width=1, length=2.5, labelsize=10) ax.yaxis.set_tick_params(which='major', width=1, length=5, labelsize=15) ax.yaxis.set_tick_params(which='minor', width=1, length=2.5, labelsize=10) ax.zaxis.set_tick_params(which='major', width=1, length=5, labelsize=15) ax.zaxis.set_tick_params(which='minor', width=1, length=2.5, labelsize=10) if "xtick_params" in kwargs: ax.xaxis.set_tick_params(**kwargs["xtick_params"]) if "ytick_params" in kwargs: ax.yaxis.set_tick_params(**kwargs["ytick_params"]) if "ztick_params" in kwargs: ax.zaxis.set_tick_params(**kwargs["ytick_params"]) #Axes labels ax.set_xlabel(r"$x$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_ylabel(r"$y$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) ax.set_zlabel(r"$z$ [{}] ".format(length_unit.to_string()),fontsize=15,labelpad=5) if "xlabel" in kwargs: ax.set_xlabel(**kwargs["xlabel"]) if "ylabel" in kwargs: ax.set_ylabel(**kwargs["ylabel"]) if "zlabel" in kwargs: ax.set_zlabel(**kwargs["zlabel"]) #Figure title ax.set_title("",fontsize=15) if "title" in kwargs: ax.set_title(**kwargs["title"]) ############### Plotting start ################ for fil_idx, c in zip(fil_idxs, colors): fil = self.get_filament(fil_idx) points = fil.samps.to_value(length_unit) ax.plot(*points.T, color=c) ax.scatter(*points.T, s=10, color=c) if(splines): spline = fil.get_points(np.linspace(0,1,fil.nsamp*100,endpoint=True)).to_value(length_unit) ax.plot(*spline.T, color=c, linestyle='--') if(cylinders): if fil.arepo_ids is None: raise ValueError(utils._prestring() + "Filament needs to be characterized first using characterize_filament() function.") if(avg is None): raise ValueError(utils._prestring()+"Need to pass ArepoVoronoiGrid to plot the cylinders!") ncyls = len(fil.arepo_ids) cmap = plt.cm.tab10 cylcolors = cmap(np.arange(ncyls)%cmap.N) for (ids,cylcolor) in zip(fil.arepo_ids,cylcolors): cylpoints = avg.pos[ids].to_value(length_unit) ax.scatter(*cylpoints.T, alpha=0.2, s=10, color=cylcolor) #z=ax.get_zlim() #ax.plot3D(fil.samps.transpose()[0],fil.samps.transpose()[1], z[0], c='grey', alpha=0.5,zorder=-10) #ax.scatter(fil.samps.transpose()[0],fil.samps.transpose()[1], z[0], c='grey', alpha=0.5,s=10,zorder=-10) #ax.set_zlim(z) #x=ax.get_xlim() #ax.plot3D([x[0]]*len(fil.samps),fil.samps.transpose()[1],fil.samps.transpose()[2], c='grey',alpha=0.5,zorder=-10) #ax.scatter([x[0]]*len(fil.samps),fil.samps.transpose()[1],fil.samps.transpose()[2], c='grey',alpha=0.5,zorder=-10) #ax.set_xlim(x) #y=ax.get_ylim() #ax.plot3D(fil.samps.transpose()[0],[y[1]]*len(fil.samps),fil.samps.transpose()[2], c='grey',alpha=0.5,zorder=-10) #ax.scatter(fil.samps.transpose()[0],[y[1]]*len(fil.samps),fil.samps.transpose()[2], c='grey',alpha=0.5,s=10,zorder=-10) #ax.set_ylim(y) ############### Plotting end ################ #Axes limits if "xlim" in kwargs: ax.set_xlim(**kwargs["xlim"]) if "ylim" in kwargs: ax.set_ylim(**kwargs["ylim"]) if "zlim" in kwargs: ax.set_zlim(**kwargs["zlim"]) if save is not None: fig.savefig(save, bbox_inches='tight', dpi=100) return fig