Axes swap between 2D geometry and tally?

This is for anyone who might come across this in the future.

I’ve been looking at this recently and solved it slightly differently: I worked out which axes had been swapped and then selected the corresponding indexes. It’s heavily inspired by your solution, Shimwell, but a bit more basic. It’s part of a post-processing class that holds the tally DataFrame, the mesh, the mesh size, etc.:

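    import matplotlib.pyplot as plt


    def plot_tallies(self):
        """
        Plots tallies by score using the tally DataFrame.
        """
        for score in self.tallies['mesh_tally'].scores:
            # Keep only the rows for this score; otherwise a multi-score
            # tally has too many values to reshape onto the mesh
            score_df = self.df[self.df['score'] == score]
            # The mesh is flattened with x varying fastest, so reversing the
            # dimensions gives an array indexed [z, y, x] (identical to the
            # plain reshape for a cubic mesh, but also correct for non-cubic ones)
            mean = score_df['mean'].to_numpy().reshape(self.mesh.dimension[::-1])

            for axis in ['X', 'Y', 'Z']:
                self.plot_slice(mean=mean, axis_to_slice=axis, score=score)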

    def plot_slice(self, mean, axis_to_slice, score): 
        """
        Method to plot a mesh tally slice.

        When we reshape the tally results from a flattened array to a 3D array,
        the axes come out in the order [Z, Y, X] rather than [X, Y, Z]. This
        method takes that into account by selecting the right indexes, so the
        proper planes (e.g. XY, XZ or YZ) are plotted.

        Args:
            mean (np.ndarray): 3D array of values for each cell in the mesh (could be mean or std. dev.)
            axis_to_slice (str): which axis to slice through ('X', 'Y' or 'Z')
            score (str): which score it is
        """

        # mean is indexed [z, y, x]; take the middle cell along the sliced axis
        if axis_to_slice == "X":
            bb_index = [1, 2]
            slice_index = mean.shape[2] // 2
            image = mean[:, :, slice_index]  # rows = Z, columns = Y
            x_label = "Y [cm]"
            y_label = "Z [cm]"
            end = '_YZ'
        elif axis_to_slice == "Y":
            bb_index = [0, 2]
            slice_index = mean.shape[1] // 2
            image = mean[:, slice_index, :]  # rows = Z, columns = X
            x_label = "X [cm]"
            y_label = "Z [cm]"
            end = '_XZ'
        elif axis_to_slice == "Z":
            bb_index = [0, 1]
            slice_index = mean.shape[0] // 2
            image = mean[slice_index, :, :]  # rows = Y, columns = X
            x_label = "X [cm]"
            y_label = "Y [cm]"
            end = '_XY'
        else:
            raise ValueError('Axis needs to be X, Y or Z')

        # Plot and save names
        plot_name = 'mesh_' + score + end
        save_name = plot_name + '.png'

        # Imshow extent
        left = self.mesh.lower_left[bb_index[0]]
        right = self.mesh.upper_right[bb_index[0]]
        bottom = self.mesh.lower_left[bb_index[1]]
        top = self.mesh.upper_right[bb_index[1]]
        extent = (left, right, bottom, top)
        
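        # Plot; origin='lower' puts row 0 (the lowest coordinate) at the
        # bottom of the extent, so the image isn't flipped vertically
        plt.title(plot_name)
        plt.imshow(image, interpolation='none', cmap='jet', extent=extent,
                   aspect='auto', origin='lower')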
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        cbar = plt.colorbar()
        cbar.set_label(score + ' per source particle')
        plt.savefig(save_name)
        plt.close()
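For anyone who wants to check the axis order on their own mesh, here’s a minimal NumPy sketch (not part of the class above). It assumes the flattened tally is ordered with x varying fastest, then y, then z, which is what produces the swap discussed in this thread. The mesh size and the `slice_mid` helper are made up for illustration. Each cell’s value encodes its (x, y, z) index, so you can see where it ends up after reshaping:

```python
import numpy as np

# Hypothetical non-cubic mesh, so a wrong axis order is easy to spot
nx, ny, nz = 4, 3, 2
dimension = (nx, ny, nz)

# Assumption: flattened order has x varying fastest, then y, then z.
# Each value encodes its cell index as 100*x + 10*y + z.
flat = np.array([100 * x + 10 * y + z
                 for z in range(nz)
                 for y in range(ny)
                 for x in range(nx)])

# C-order reshape with the reversed dimensions -> array indexed [z, y, x]
mean = flat.reshape(dimension[::-1])

# Naive reshape with (nx, ny, nz): indexes no longer mean (x, y, z)
naive = flat.reshape(dimension)


def slice_mid(mean, axis_to_slice):
    """Return the mid-plane image of a [z, y, x] array (hypothetical helper)."""
    if axis_to_slice == "X":
        return mean[:, :, mean.shape[2] // 2]  # rows = Z, columns = Y
    if axis_to_slice == "Y":
        return mean[:, mean.shape[1] // 2, :]  # rows = Z, columns = X
    if axis_to_slice == "Z":
        return mean[mean.shape[0] // 2, :, :]  # rows = Y, columns = X
    raise ValueError("Axis needs to be X, Y or Z")


print(mean.shape)                  # (2, 3, 4)
print(mean[1, 2, 3])               # 321 -> z=1, y=2, x=3
print(naive[1, 0, 0])              # 210, not 100: the axes are scrambled
print(slice_mid(mean, "Z").shape)  # (3, 4)
```

Using a non-cubic mesh is what makes the mistake visible: with a cubic mesh, `reshape(dimension)` and `reshape(dimension[::-1])` give arrays of the same shape, so the swap only shows up as transposed plots.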