This is for anyone who might come across this in the future.
I’ve been looking at this recently and solved it slightly differently: I worked out which axes had been swapped and then selected the corresponding indexes. It’s heavily inspired by your solution, Shimwell, but a bit more basic. The code is part of a post-processing class that holds the tally DataFrame, the mesh, the mesh size, etc.:
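import matplotlib.pyplot as plt

def plot_tallies(self):
    """
    Plots tallies by score using the tally DataFrame.
    """
    for score in self.tallies['mesh_tally'].scores:
        # Reversing the mesh dimensions gives an array indexed as [z, y, x];
        # for a cubic mesh this is identical to reshaping with mesh.dimension.
        mean = self.df['mean'].to_numpy().reshape(self.mesh.dimension[::-1])
        axes_to_slice = ['X', 'Y', 'Z']
        for axis in axes_to_slice:
            self.plot_slice(mean=mean, axis_to_slice=axis, score=score)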
def plot_slice(self, mean, axis_to_slice, score):
    """
    Method to plot a mesh tally slice.

    When we reshape the tally results from a flattened array to a 3D array, the
    axes come out in the wrong order: the array is indexed as [z, y, x]. This
    method takes that into account by selecting the right indexes, which makes
    it possible to plot the proper planes (e.g. XY, XZ or YZ).

    Args:
        mean (np array): 3D np array of values of each cell in the mesh (could be mean or std. dev.)
        axis_to_slice (str): axis normal to the plotted plane ("X", "Y" or "Z")
        score (str): which score it is
    """
    # The slice is always taken through the middle cell along the chosen axis.
    if axis_to_slice == "X":
        bb_index = [1, 2]
        slice_index = mean.shape[2] // 2
        image = mean[:, :, slice_index]
        x_label = "Y [cm]"
        y_label = "Z [cm]"
        end = '_YZ'
    elif axis_to_slice == "Y":
        bb_index = [0, 2]
        slice_index = mean.shape[1] // 2
        image = mean[:, slice_index, :]
        x_label = "X [cm]"
        y_label = "Z [cm]"
        end = '_XZ'
    elif axis_to_slice == 'Z':
        bb_index = [0, 1]
        slice_index = mean.shape[0] // 2
        image = mean[slice_index, :, :]
        x_label = "X [cm]"
        y_label = "Y [cm]"
        end = '_XY'
    else:
        raise ValueError('Axis needs to be X, Y or Z')
    # Plot and save names
    plot_name = 'mesh_' + score + end
    save_name = plot_name + '.png'
    # Imshow extent
    left = self.mesh.lower_left[bb_index[0]]
    right = self.mesh.upper_right[bb_index[0]]
    bottom = self.mesh.lower_left[bb_index[1]]
    top = self.mesh.upper_right[bb_index[1]]
    extent = (left, right, bottom, top)
    # Plot; origin='lower' draws row 0 (lowest coordinate) at the bottom so
    # the image agrees with the extent
    plt.title(plot_name)
    plt.imshow(image, interpolation='none', cmap='jet', extent=extent, aspect='auto', origin='lower')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    cbar = plt.colorbar()
    cbar.set_label(score + ' per source particle')
    plt.savefig(save_name)
    plt.close()
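
If you want to convince yourself of the index order without running a simulation, here is a minimal NumPy-only sketch. It assumes the flattened tally results have x varying fastest, then y, then z. The mesh size, the `flat` array and the `slice_plane` helper are made up for illustration and are not part of the class above. Each value encodes its own (x, y, z) index, so it is easy to check which array axis is which.

```python
import numpy as np

# Hypothetical non-cubic mesh, so a wrong axis order shows up in the shapes.
nx, ny, nz = 2, 3, 4

# Flat array with x varying fastest, then y, then z (assumed to match the
# row order of the tally DataFrame). Each value is 100*x + 10*y + z.
flat = np.array([100 * x + 10 * y + z
                 for z in range(nz)
                 for y in range(ny)
                 for x in range(nx)])

# Reshaping with the dimensions reversed gives an array indexed as [z, y, x].
mean = flat.reshape((nz, ny, nx))


def slice_plane(values, axis_to_slice):
    """Return the mid-plane normal to the given axis of a [z, y, x] array."""
    if axis_to_slice == "X":
        return values[:, :, values.shape[2] // 2]  # YZ plane: rows z, cols y
    if axis_to_slice == "Y":
        return values[:, values.shape[1] // 2, :]  # XZ plane: rows z, cols x
    if axis_to_slice == "Z":
        return values[values.shape[0] // 2, :, :]  # XY plane: rows y, cols x
    raise ValueError("Axis needs to be X, Y or Z")


yz = slice_plane(mean, "X")
xz = slice_plane(mean, "Y")
xy = slice_plane(mean, "Z")
print(yz.shape, xz.shape, xy.shape)
```

In every slice the rows correspond to the vertical plot axis, which is why the labels in `plot_slice` put the later coordinate on the y axis.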