Module deeplenstronomy.visualize
Functions to visualize images.
Expand source code
"""Functions to visualize images."""
from astropy.visualization import make_lupton_rgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def _no_stretch(val):
return val
def view_image(image, stretch_func=_no_stretch, **imshow_kwargs):
    """
    Plot an image.

    Args:
        image (array): a 2-dimensional array of pixel values OR a list-like object of 2-dimensional arrays of pixel values
        stretch_func (func, optional, default=_no_stretch): stretching function to apply to pixel values (e.g. np.log10); the default leaves the values unchanged
        imshow_kwargs (dict): dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
    """
    if len(np.shape(image)) > 2:
        # Multi-band mode: one panel per band, laid out in a single row.
        # squeeze=False keeps `axs` a 2-D array even when there is only one
        # band; otherwise plt.subplots(1, 1) returns a bare Axes and indexing
        # it would fail for a (1, H, W) stack.
        fig, axs = plt.subplots(1, np.shape(image)[0], squeeze=False)
        for ax, single_band_image in zip(axs[0], image):
            ax.imshow(stretch_func(single_band_image), **imshow_kwargs)
            # Hide tick marks; the axes carry no physical coordinates here.
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show(block=True)
        plt.close(fig)
    else:
        # Single-band mode.
        plt.figure()
        plt.imshow(stretch_func(image), **imshow_kwargs)
        plt.xticks([], [])
        plt.yticks([], [])
        plt.show(block=True)
        plt.close()
def view_image_rgb(images, Q=2.0, stretch=4.0, **imshow_kwargs):
    """
    Merge images into a single RGB image and display it.

    The bands are expected in [g, r, i] order. The i band feeds the red
    channel, r the green channel and g the blue channel of the composite
    built by astropy's make_lupton_rgb.

    Args:
        images (List[np.array]): a list of at least 3 2-dimensional arrays of pixel values corresponding to different photometric bandpasses; only the first three are used
        Q (float, optional, default=2.0): passed through to make_lupton_rgb (asinh softening parameter)
        stretch (float, optional, default=4.0): passed through to make_lupton_rgb (linear stretch)
        imshow_kwargs (dict): dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
    """
    assert len(images) > 2, "3 images are needed to generate an RGB image"
    i_band, r_band, g_band = images[2], images[1], images[0]
    composite = make_lupton_rgb(i_band, r_band, g_band, Q=Q, stretch=stretch)

    plt.figure()
    plt.imshow(composite, **imshow_kwargs)
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([], [])
    plt.show(block=True)
    plt.close()
def view_corner(metadata, labels, hist_kwargs=None, hist2d_kwargs=None, label_kwargs=None):
    """
    Show a corner plot of the columns in a DataFrame.

    The diagonal shows a 1-D histogram of each column, the lower triangle
    shows 2-D histograms of each pair of columns, and the upper triangle
    is hidden.

    Args:
        metadata (pd.DataFrame): A pandas DataFrame containing the metadata to visualize
        labels (dict): A dictionary mapping column names to axis labels; its iteration order sets the panel order
        hist_kwargs (dict, optional): keyword arguments to pass to matplotlib.pyplot.hist
        hist2d_kwargs (dict, optional): keyword arguments to pass to matplotlib.pyplot.hist2d
        label_kwargs (dict, optional): keyword arguments to pass to matplotlib.axes.Axes.set_xlabel (and ylabel)

    Raises:
        TypeError: if metadata is not a pandas DataFrame
        TypeError: if labels is not a dict
        KeyError: if one or more of the columns are not present in the metadata
        ValueError: if labels is empty
    """
    if not isinstance(metadata, pd.DataFrame):
        raise TypeError("first argument must be a pandas DataFrame")
    if not isinstance(labels, dict):
        raise TypeError("second argument must be a dict")
    missing = [column for column in labels if column not in metadata.columns]
    if missing:
        raise KeyError(
            "One or more passed columns is not present in the metadata: {}".format(missing))
    if not labels:
        raise ValueError("labels must contain at least one column to plot")

    # None defaults avoid sharing mutable dicts between calls.
    hist_kwargs = {} if hist_kwargs is None else hist_kwargs
    hist2d_kwargs = {} if hist2d_kwargs is None else hist2d_kwargs
    label_kwargs = {} if label_kwargs is None else label_kwargs

    columns = list(labels)
    n_columns = len(columns)
    # squeeze=False keeps `axs` 2-D even for a single column, so
    # axs[row, col] indexing always works.
    fig, axs = plt.subplots(n_columns, n_columns, figsize=(14, 14), squeeze=False)
    for row, row_label in enumerate(columns):
        for col, col_label in enumerate(columns):
            ax = axs[row, col]
            if row == col:
                # Diagonal: 1-D distribution of a single column.
                ax.hist(metadata[row_label].values, **hist_kwargs)
            elif row > col:
                # Lower triangle: joint distribution, x = column, y = row.
                ax.hist2d(metadata[col_label].values,
                          metadata[row_label].values, **hist2d_kwargs)
            else:
                ax.set_visible(False)
            # Only the outer edge of the grid gets axis labels.
            if row == n_columns - 1:
                ax.set_xlabel(labels[col_label], **label_kwargs)
            if col == 0 and row != 0:
                ax.set_ylabel(labels[row_label], **label_kwargs)
    fig.tight_layout()
    plt.show()
    plt.close(fig)
Functions
def view_corner(metadata, labels, hist_kwargs={}, hist2d_kwargs={}, label_kwargs={})
-
Show a corner plot of the columns in a DataFrame.
Args
metadata
:pd.DataFrame
- A pandas DataFrame containing the metadata to visualize
labels
:dict
- A dictionary mapping column names to axis labels
hist_kwargs
:dict
- keyword arguments to pass to matplotlib.pyplot.hist
hist2d_kwargs
:dict
- keyword arguments to pass to matplotlib.pyplot.hist2d
label_kwargs
:dict
- keyword arguments to pass to matplotlib.axes.Axes.set_xlabel (and ylabel)
Raises
KeyError
- if one or more of the columns are not present in the metadata
TypeError
- if metadata is not a pandas DataFrame
TypeError
- if labels is not a dict
Expand source code
def view_corner(metadata, labels, hist_kwargs=None, hist2d_kwargs=None, label_kwargs=None):
    """
    Show a corner plot of the columns in a DataFrame.

    The diagonal shows a 1-D histogram of each column, the lower triangle
    shows 2-D histograms of each pair of columns, and the upper triangle
    is hidden.

    Args:
        metadata (pd.DataFrame): A pandas DataFrame containing the metadata to visualize
        labels (dict): A dictionary mapping column names to axis labels; its iteration order sets the panel order
        hist_kwargs (dict, optional): keyword arguments to pass to matplotlib.pyplot.hist
        hist2d_kwargs (dict, optional): keyword arguments to pass to matplotlib.pyplot.hist2d
        label_kwargs (dict, optional): keyword arguments to pass to matplotlib.axes.Axes.set_xlabel (and ylabel)

    Raises:
        TypeError: if metadata is not a pandas DataFrame
        TypeError: if labels is not a dict
        KeyError: if one or more of the columns are not present in the metadata
        ValueError: if labels is empty
    """
    if not isinstance(metadata, pd.DataFrame):
        raise TypeError("first argument must be a pandas DataFrame")
    if not isinstance(labels, dict):
        raise TypeError("second argument must be a dict")
    missing = [column for column in labels if column not in metadata.columns]
    if missing:
        raise KeyError(
            "One or more passed columns is not present in the metadata: {}".format(missing))
    if not labels:
        raise ValueError("labels must contain at least one column to plot")

    # None defaults avoid sharing mutable dicts between calls.
    hist_kwargs = {} if hist_kwargs is None else hist_kwargs
    hist2d_kwargs = {} if hist2d_kwargs is None else hist2d_kwargs
    label_kwargs = {} if label_kwargs is None else label_kwargs

    columns = list(labels)
    n_columns = len(columns)
    # squeeze=False keeps `axs` 2-D even for a single column, so
    # axs[row, col] indexing always works.
    fig, axs = plt.subplots(n_columns, n_columns, figsize=(14, 14), squeeze=False)
    for row, row_label in enumerate(columns):
        for col, col_label in enumerate(columns):
            ax = axs[row, col]
            if row == col:
                # Diagonal: 1-D distribution of a single column.
                ax.hist(metadata[row_label].values, **hist_kwargs)
            elif row > col:
                # Lower triangle: joint distribution, x = column, y = row.
                ax.hist2d(metadata[col_label].values,
                          metadata[row_label].values, **hist2d_kwargs)
            else:
                ax.set_visible(False)
            # Only the outer edge of the grid gets axis labels.
            if row == n_columns - 1:
                ax.set_xlabel(labels[col_label], **label_kwargs)
            if col == 0 and row != 0:
                ax.set_ylabel(labels[row_label], **label_kwargs)
    fig.tight_layout()
    plt.show()
    plt.close(fig)
def view_image(image, stretch_func=<function _no_stretch>, **imshow_kwargs)
-
Plot an image.
Args
image
:array
- a 2-dimensional array of pixel values OR a list-like object of 2-dimensional arrays of pixel values
stretch_func
:func
, optional, default=_no_stretch (identity; pixel values are left unchanged)
- stretching function to apply to pixel values (e.g. np.log10)
imshow_kwargs
:dict
- dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
Expand source code
def view_image(image, stretch_func=_no_stretch, **imshow_kwargs):
    """
    Plot an image.

    Args:
        image (array): a 2-dimensional array of pixel values OR a list-like object of 2-dimensional arrays of pixel values
        stretch_func (func, optional, default=_no_stretch): stretching function to apply to pixel values (e.g. np.log10); the default leaves the values unchanged
        imshow_kwargs (dict): dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
    """
    if len(np.shape(image)) > 2:
        # Multi-band mode: one panel per band, laid out in a single row.
        # squeeze=False keeps `axs` a 2-D array even when there is only one
        # band; otherwise plt.subplots(1, 1) returns a bare Axes and indexing
        # it would fail for a (1, H, W) stack.
        fig, axs = plt.subplots(1, np.shape(image)[0], squeeze=False)
        for ax, single_band_image in zip(axs[0], image):
            ax.imshow(stretch_func(single_band_image), **imshow_kwargs)
            # Hide tick marks; the axes carry no physical coordinates here.
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show(block=True)
        plt.close(fig)
    else:
        # Single-band mode.
        plt.figure()
        plt.imshow(stretch_func(image), **imshow_kwargs)
        plt.xticks([], [])
        plt.yticks([], [])
        plt.show(block=True)
        plt.close()
def view_image_rgb(images, Q=2.0, stretch=4.0, **imshow_kwargs)
-
Merge images into a single RGB image. This function assumes the image array is ordered [g, r, i].
Args
images
:List[np.array]
- a list of at least 3 2-dimensional arrays of pixel values corresponding to different photometric bandpasses
imshow_kwargs
:dict
- dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
Expand source code
def view_image_rgb(images, Q=2.0, stretch=4.0, **imshow_kwargs):
    """
    Merge images into a single RGB image and display it.

    The bands are expected in [g, r, i] order. The i band feeds the red
    channel, r the green channel and g the blue channel of the composite
    built by astropy's make_lupton_rgb.

    Args:
        images (List[np.array]): a list of at least 3 2-dimensional arrays of pixel values corresponding to different photometric bandpasses; only the first three are used
        Q (float, optional, default=2.0): passed through to make_lupton_rgb (asinh softening parameter)
        stretch (float, optional, default=4.0): passed through to make_lupton_rgb (linear stretch)
        imshow_kwargs (dict): dictionary of keyword arguments and their values to pass to matplotlib.pyplot.imshow
    """
    assert len(images) > 2, "3 images are needed to generate an RGB image"
    i_band, r_band, g_band = images[2], images[1], images[0]
    composite = make_lupton_rgb(i_band, r_band, g_band, Q=Q, stretch=stretch)

    plt.figure()
    plt.imshow(composite, **imshow_kwargs)
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([], [])
    plt.show(block=True)
    plt.close()