# Source code for pynumpm.jsa

# coding=utf-8
"""
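Module to simulate the 2D spectrum of a pump field and to combine it with a phasematching function in order to
compute the joint spectral amplitude (JSA) of nonlinear processes, its Schmidt decomposition, and related plots.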

"""
import logging
from scipy.special import hermite, factorial
from scipy.constants import c as _sol
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import enum
import warnings
from typing import Union
from pynumpm import phasematching


class Process(enum.Enum):
    """
    Enum class used to describe the nonlinear processes.

    It selects how :class:`Pump` maps the (wavelength1, wavelength2) grid onto the pump wavelength, i.e. how the
    pump spectrum is calculated.
    """

    # Sum-frequency generation: wavelength1 is the input field, wavelength2 the output field.
    SFG = "SFG. Input is the wavelength1, Output is the wavelength2."
    # Difference-frequency generation: wavelength1 is the input field, wavelength2 the output field.
    DFG = "DFG. Input is the wavelength1, Output is the wavelength2."
    # Parametric down-conversion: wavelength1/wavelength2 are plotted as signal/idler.
    PDC = "PDC."
    # Backward-wave PDC: wavelength1 is the forward propagating field, wavelength2 the backward one.
    BWPDC = "BWPDC. Wavelength1 is the forward propagating field, wavelength2 is the backward one."  # TODO: Check this
class Pump(object):
    r"""
    Pump class. It is used to describe 2D pump functions for the calculation of JSA.

    The pump is modelled as

    .. math::

        \alpha(\omega_1, \omega_2) = \mathrm{HG}_n(\omega_1, \omega_2)\cdot
        \mathrm{exp}\left\lbrace -\frac{(\omega_p - \omega_{p,0})^2}{2\sigma^2}\right\rbrace
        \cdot \mathrm{Chirp} \cdot \mathrm{filter}

    Initialize the pump object calling the class and passing a suitable `Process` element.
    The parameters of the pump must be assigned by the user. The following attributes can be modified:

    :var pump_spectrum: 2D complex array with the pump function. Rows follow `wavelength2` and columns follow
        `wavelength1` (the layout of ``np.meshgrid(wavelength1, wavelength2)``). It is filled by
        :meth:`calculate_pump_spectrum` or can be assigned by the user.
    :var wavelength1: 1D array with the first wavelength axis (signal for PDC/BWPDC, input for SFG/DFG).
    :var wavelength2: 1D array with the second wavelength axis (idler for PDC/BWPDC, output for SFG/DFG).
    :var pump_centre: Central pump wavelength. If None, it is inferred from `wavelength1`/`wavelength2`
        (with a UserWarning) when the spectrum is calculated.
    :var pump_width: Pump width. It is divided by :math:`2\sqrt{\ln 2}` before being used as the Gaussian width
        of the Hermite-Gaussian mode (presumably an intensity FWHM; see the TODO in `_hermite_mode`).
    :var pump_temporal_mode: Order n of the Hermite-Gaussian mode (non-negative integer, default 0).
    :var pump_chirp: Coefficient multiplying the squared pump angular frequency in the chirp phase (default 0).
    :var pump_delay: Coefficient multiplying the pump angular frequency in the delay phase (default 0).
    :var filter_width: Width of a rectangular filter centred on `pump_centre`. Assigning a value enables the
        filter, assigning None disables it. The width is scaled by sqrt(2) when the spectrum is calculated.
    """

    def __init__(self, process: "Process"):
        """
        :param process: Process under investigation. An element of the class `pynumpm.jsa.Process`.
        :type process: Process
        """
        self.__process = None
        self.process = process
        self.__pump_centre = None
        self.__pump_width = None
        self.__filter_pump = False
        self.__pump_delay = 0
        self.__pump_chirp = 0
        self.__pump_temporal_mode = 0
        self.__pump_filter_width = None
        self.__wavelength1 = None
        self.__wavelength2 = None
        self.__wavelength1_2D = None
        self.__wavelength2_2D = None
        self.__pump_spectrum = None
        self.__correct_pump_width = None
        # 2D grid of pump wavelengths; filled by __set_wavelengths().
        self.pump_wavelength = None

    @property
    def pump_spectrum(self):
        return self.__pump_spectrum

    @pump_spectrum.setter
    def pump_spectrum(self, value: np.ndarray):
        if not isinstance(value, np.ndarray):
            raise TypeError("The 'pump_spectrum' must be a numpy.ndarray object.")
        if self.wavelength1 is None or self.wavelength2 is None:
            raise ValueError("Set wavelength1 and wavelength2 before assigning a custom pump_spectrum.")
        if value.ndim != 2:
            raise ValueError("The 'pump_spectrum' must be a 2D array.")
        # Same layout as np.meshgrid(wavelength1, wavelength2): columns <-> wavelength1, rows <-> wavelength2.
        if value.shape[1] != len(self.wavelength1):
            raise ValueError("The input pump spectrum is not compatible with the signal wavelength set. "
                             "Make sure pump_spectrum.shape[1] == len(wavelength1)")
        if value.shape[0] != len(self.wavelength2):
            raise ValueError("The input pump spectrum is not compatible with the idler wavelength set. "
                             "Make sure pump_spectrum.shape[0] == len(wavelength2)")
        self.__pump_spectrum = value

    @property
    def wavelength1(self):
        return self.__wavelength1

    @wavelength1.setter
    def wavelength1(self, value):
        # Stored as an array so that .mean(), arithmetic and plotting work for list inputs too.
        self.__wavelength1 = None if value is None else np.asarray(value)

    @property
    def wavelength2(self):
        return self.__wavelength2

    @wavelength2.setter
    def wavelength2(self, value):
        self.__wavelength2 = None if value is None else np.asarray(value)

    @property
    def process(self):
        return self.__process

    @process.setter
    def process(self, value):
        if not isinstance(value, Process):
            raise TypeError("The type of 'process' must be pynumpm.jsa.Process")
        self.__process = value

    @property
    def pump_centre(self):
        return self.__pump_centre

    @pump_centre.setter
    def pump_centre(self, value):
        self.__pump_centre = value

    @property
    def pump_width(self):
        return self.__pump_width

    @pump_width.setter
    def pump_width(self, value):
        self.__pump_width = value

    @property
    def filter_width(self):
        return self.__pump_filter_width

    @filter_width.setter
    def filter_width(self, value):
        # Assigning a width enables the filter; assigning None disables it.
        self.__filter_pump = value is not None
        self.__pump_filter_width = value

    @property
    def pump_delay(self):
        return self.__pump_delay

    @pump_delay.setter
    def pump_delay(self, value):
        self.__pump_delay = value

    @property
    def pump_chirp(self):
        return self.__pump_chirp

    @pump_chirp.setter
    def pump_chirp(self, value):
        self.__pump_chirp = value

    @property
    def pump_temporal_mode(self):
        return self.__pump_temporal_mode

    @pump_temporal_mode.setter
    def pump_temporal_mode(self, value: int):
        errormsg = "The pump temporal mode has to be non-negative integer"
        # Accept numpy integers too (e.g. values coming out of np.arange).
        if not isinstance(value, (int, np.integer)):
            raise TypeError(errormsg)
        if value < 0:
            raise ValueError(errormsg)
        self.__pump_temporal_mode = int(value)

    @property
    def wavelength1_2D(self):
        return self.__wavelength1_2D

    @property
    def wavelength2_2D(self):
        return self.__wavelength2_2D

    def _hermite_mode(self, x: float):
        """
        A normalised Hermite-Gaussian function of order `pump_temporal_mode`, centred on `pump_centre`.

        It uses the Gaussian width computed in :meth:`calculate_pump_spectrum`, so it must be called after it.

        :param x: Wavelength(s) at which the mode is evaluated (scalar or numpy array).
        :return: The mode evaluated at `x`.
        """
        # TODO: Check the correctness of the __correct_pump_width parameter
        order = self.pump_temporal_mode
        width = self.__correct_pump_width
        u = (self.pump_centre - x) / width
        return hermite(order)(u) * np.exp(-u ** 2 / 2) / \
            np.sqrt(factorial(order) * np.sqrt(np.pi) * 2 ** order * width)

    def __set_wavelengths(self):
        """
        Build the 2D wavelength grids and the pump wavelength grid `pump_wavelength` for the current process.

        The pump wavelength follows from:

        * PDC/BWPDC: 1/lambda_p = 1/wavelength1 + 1/wavelength2
        * SFG: 1/lambda_p = 1/wavelength2 - 1/wavelength1
        * DFG: 1/lambda_p = 1/wavelength1 - 1/wavelength2

        If `pump_centre` is unset, it is inferred from the means of `wavelength1` and `wavelength2` and a
        UserWarning is emitted.

        :raises ValueError: if `wavelength1` or `wavelength2` is not set.
        :raises NotImplementedError: if the process has no pump relation implemented.
        """
        missing = []
        if self.wavelength1 is None:
            missing.append("You need to set wavelength1.")
        if self.wavelength2 is None:
            missing.append("You need to set wavelength2.")
        if missing:
            raise ValueError(" ".join(missing))
        self.__wavelength1_2D, self.__wavelength2_2D = np.meshgrid(self.wavelength1, self.wavelength2)
        wl1_2d, wl2_2d = self.__wavelength1_2D, self.__wavelength2_2D
        message = "The pump central wavelength hasn't been set. Inferring its value from the wavelength1/wavelength2 arrays"
        if self.process in (Process.PDC, Process.BWPDC):
            self.pump_wavelength = 1.0 / (1.0 / wl1_2d + 1.0 / wl2_2d)
            if self.pump_centre is None:
                warnings.warn(message, UserWarning)
                self.pump_centre = (self.wavelength1.mean() ** -1 + self.wavelength2.mean() ** -1) ** -1
        elif self.process == Process.SFG:
            self.pump_wavelength = 1.0 / (1.0 / wl2_2d - 1.0 / wl1_2d)
            if self.pump_centre is None:
                warnings.warn(message, UserWarning)
                self.pump_centre = (self.wavelength2.mean() ** -1 - self.wavelength1.mean() ** -1) ** -1
        elif self.process == Process.DFG:
            self.pump_wavelength = 1.0 / (1.0 / wl1_2d - 1.0 / wl2_2d)
            if self.pump_centre is None:
                warnings.warn(message, UserWarning)
                self.pump_centre = (self.wavelength1.mean() ** -1 - self.wavelength2.mean() ** -1) ** -1
        else:
            raise NotImplementedError("The process {0} has not been implemented yet.".format(self.process))

    def calculate_pump_spectrum(self):
        """
        This function calculates the pump function on the (wavelength1, wavelength2) grid.

        The pump is the Hermite-Gaussian mode of the pump wavelength, multiplied by the delay phase
        exp(i*omega*pump_delay) and the chirp phase exp(i*omega**2*pump_chirp), where omega = 2*pi*c/lambda_p.
        If a filter width has been set, it is also multiplied by a rectangular filter centred on `pump_centre`.
        The result is also stored in `pump_spectrum`. Calling this method repeatedly with the same parameters
        gives the same result.

        :return: matrix containing the pump function in the wavelength1 and wavelength2 plane,
            with shape (len(wavelength2), len(wavelength1)).
        :raises ValueError: if `pump_width`, `wavelength1` or `wavelength2` is not set.
        """
        logger = logging.getLogger(__name__)
        if self.pump_width is None:
            raise ValueError("You need to set pump_width.")
        self.__set_wavelengths()
        # Convert the user-facing width into the Gaussian width used by _hermite_mode.
        self.__correct_pump_width = self.pump_width / (2 * np.sqrt(np.log(2)))
        logger.debug("Pump wavelength grid shape: %s", np.shape(self.pump_wavelength))

        pump_omega = 2 * np.pi * _sol / self.pump_wavelength
        pump_function = self._hermite_mode(self.pump_wavelength) * \
            np.exp(1j * pump_omega * self.pump_delay) * \
            np.exp(1j * pump_omega ** 2 * self.pump_chirp)

        if self.__filter_pump:
            # The sqrt(2) scaling is applied to a local copy. Previously filter_width itself was rescaled,
            # which widened the filter on every call.
            half_width = 0.5 * self.filter_width * np.sqrt(2)
            in_band = (self.pump_wavelength >= self.pump_centre - half_width) & \
                      (self.pump_wavelength <= self.pump_centre + half_width)
            pump_function = pump_function * in_band.astype(float)

        self.__pump_spectrum = pump_function
        return pump_function

    def plot(self, ax=None, light_plot=False, **kwargs):
        """
        Plot the pump intensity |pump_spectrum|**2 over the wavelength1/wavelength2 plane, in nm.

        The spectrum is calculated first if it is not available.

        :param ax: Axes where to plot. If None, a new figure is created.
        :param light_plot: Use `imshow` instead of `pcolormesh`. It is only correct for linear wavelength meshes.
        :param kwargs: Currently unused.
        :return: the axes handle for the plot.
        """
        if ax is None:
            _, ax = plt.subplots(1, 1)
        if self.pump_spectrum is None:
            self.calculate_pump_spectrum()
        intensity = abs(self.pump_spectrum) ** 2
        x = self.wavelength1 * 1e9
        y = self.wavelength2 * 1e9
        if light_plot:
            ax.imshow(intensity, origin="lower", extent=[x.min(), x.max(), y.min(), y.max()], aspect="auto")
            warnings.warn("The light_plot mode is compatible only with linear meshes of the signal/idler wavelengths.")
        else:
            ax.pcolormesh(x, y, intensity)
        ax.set_title("Pump intensity")
        if self.process in (Process.SFG, Process.DFG):
            ax.set_xlabel(r"$\lambda_{input}$ [nm]")
            ax.set_ylabel(r"$\lambda_{output}$ [nm]")
        else:
            ax.set_xlabel(r"$\lambda_{signal}$ [nm]")
            ax.set_ylabel(r"$\lambda_{idler}$ [nm]")
        ax.figure.tight_layout()
        return ax
class JSA(object):
    """
    Joint spectral amplitude (JSA) of a nonlinear process.

    The JSA is the product of the pump spectrum (:class:`Pump`) and the phasematching function ``phi`` of a
    phasematching object, both sampled on the phasematching ``wavelength1``/``wavelength2`` grid. The class also
    provides the joint spectral intensity (JSI), the marginals, the Schmidt decomposition and plotting helpers.
    """

    def __init__(self,
                 phasematching: "Union[phasematching.SimplePhasematching2D, phasematching.Phasematching2D]",
                 pump: "Pump"):
        """
        :param phasematching: Phasematching object. It must expose ``wavelength1``, ``wavelength2`` and ``phi``.
        :param pump: Pump object. If it has no pump spectrum yet, it is calculated on assignment.
        """
        self.__phasematching = phasematching
        self.__pump = None
        self.pump = pump
        self.__JSA = None
        self.__JSI = None
        self.__K = None
        self.__marginal1 = None
        self.__marginal2 = None
        self.__singular_values = None

    @property
    def phasematching(self):
        return self.__phasematching

    @phasematching.setter
    def phasematching(self, value):
        self.__phasematching = value

    @property
    def pump(self):
        return self.__pump

    @pump.setter
    def pump(self, value: "Pump"):
        if not isinstance(value, Pump):
            raise TypeError("The pump must be an object of the class pynumpm.jsa.Pump")
        if value.pump_spectrum is None:
            if value.wavelength1 is None and value.wavelength2 is None and self.__phasematching is not None:
                # The pump has no grid of its own: adopt the phasematching one.
                value.wavelength1 = self.__phasematching.wavelength1
                value.wavelength2 = self.__phasematching.wavelength2
            value.calculate_pump_spectrum()
        self.__pump = value

    @property
    def jsa(self):
        return self.__JSA

    @property
    def jsi(self):
        return self.__JSI

    @property
    def schmidt_number(self):
        return self.__K

    @property
    def marginal1(self):
        return self.__marginal1

    @property
    def marginal2(self):
        return self.__marginal2

    @property
    def singular_values(self):
        return self.__singular_values

    def calculate_JSA(self):
        """
        Function to calculate the JSA as the product of the pump spectrum and the phasematching ``phi``.

        The pump is evaluated on the phasematching grid. If its wavelength arrays differ from
        ``phasematching.wavelength1``/``wavelength2``, or no pump spectrum is available, they are overwritten
        with the phasematching arrays and the pump spectrum is recalculated. A pump spectrum already on the
        right grid (including a user-assigned one) is used as is.

        The JSA is normalised so that sum(|JSA|**2 * |dw1 * dw2|) = 1, where dw1 and dw2 are the local grid
        steps (``np.gradient``) of wavelength1 and wavelength2. With this normalisation each marginal sums to 1
        when weighted by its own grid step.

        :return: tuple (JSA, JSI)
        :raises ValueError: if the product of pump and phasematching is identically zero.
        """
        logger = logging.getLogger(__name__)
        logger.info("Calculating JSA")
        signal_wl = np.asarray(self.phasematching.wavelength1)
        idler_wl = np.asarray(self.phasematching.wavelength2)

        pump = self.pump
        same_grid = (pump.wavelength1 is not None and pump.wavelength2 is not None
                     and np.array_equal(pump.wavelength1, signal_wl)
                     and np.array_equal(pump.wavelength2, idler_wl))
        if pump.pump_spectrum is None or not same_grid:
            # Assign the 1D axes; the pump builds its own meshgrid internally.
            pump.wavelength1 = signal_wl
            pump.wavelength2 = idler_wl
            pump.calculate_pump_spectrum()

        JSA = pump.pump_spectrum * self.phasematching.phi

        # Local grid steps; abs() makes decreasing wavelength arrays work as well.
        DWSIG, DWIDL = np.meshgrid(np.abs(np.gradient(signal_wl)), np.abs(np.gradient(idler_wl)))
        norm = np.sqrt((np.abs(JSA) ** 2 * DWSIG * DWIDL).sum())
        if norm == 0:
            raise ValueError("The JSA is identically zero: the pump and phasematching functions do not overlap.")
        JSA = JSA / norm
        JSI = np.abs(JSA) ** 2
        self.__marginal1 = (JSI * DWIDL).sum(axis=0)
        self.__marginal2 = (JSI * DWSIG).sum(axis=1)
        self.__JSA = JSA
        self.__JSI = JSI
        return self.__JSA, self.__JSI

    def calculate_schmidt_decomposition(self, verbose=False):
        """
        Function to calculate the Schmidt decomposition of the JSA via a singular value decomposition.

        The singular values are normalised so that their squares sum to 1. The Schmidt number is
        K = 1 / sum(s_i**4) and the purity is 1/K. If the JSA has not been calculated yet, it is calculated first.

        :param bool verbose: Print to screen the Schmidt number and the purity of the state.
        :return: tuple (K, U, s): the Schmidt number, the matrix U of the SVD and the normalised singular values.
        """
        logger = logging.getLogger(__name__)
        if self.jsa is None:
            self.calculate_JSA()
        U, singular_values, _ = np.linalg.svd(self.__JSA)
        singular_values = singular_values / np.sqrt((singular_values ** 2).sum())
        self.__singular_values = singular_values
        self.__K = 1 / (singular_values ** 4).sum()
        text = "Schmidt number K: {K}\nPurity: {P}".format(K=self.schmidt_number, P=1 / self.schmidt_number)
        if verbose:
            print(text)
        logger.info(text)
        logger.debug("Check normalization: sum of s^2 = " + str((abs(self.__singular_values) ** 2).sum()))
        return self.__K, U, self.__singular_values

    def plot_schmidt_coefficients(self, ncoeff: int = 20, ax: "plt.Axes" = None):
        """
        Function to plot the first `ncoeff` Schmidt coefficients as a bar plot.

        If fewer coefficients are available, all of them are plotted.

        :param ncoeff: Number of coefficients to plot. Default: 20
        :type ncoeff: int
        :param ax: Handles to the axis object where to plot. If None, a new figure is created.
        :type ax: `matplotlib.axes.Axes`
        :return: the axes handle.
        """
        if self.__singular_values is None:
            self.calculate_schmidt_decomposition()
        if ax is None:
            _, ax = plt.subplots()
        n_plot = min(ncoeff, len(self.__singular_values))
        ax.bar(range(n_plot), self.__singular_values[:n_plot], 0.8)
        ax.set_xlabel("Mode number $i$")
        ax.set_ylabel(r"$\sqrt{\lambda_i}$")
        return ax

    def plot(self, ax=None, light_plot=False, normalized=True, title="JSI", plot_pump=False):
        """
        Function to plot the JSI. Pass an axes handle through "ax" to plot in a specified axis environment.

        The stored JSI is not modified by this method.

        :param ax: Axes handles. If None, a new figure is created.
        :type ax: matplotlib.pyplot.axes
        :param light_plot: Flag to allow plotting in the *light* mode. The light_plot mode is compatible only with
            linear meshes of the signal/idler wavelengths. Default is False.
        :type light_plot: bool
        :param normalized: Flag to plot the JSI normalized to its maximum or unnormalized. Default is True.
        :type normalized: bool
        :param title: Title of the plot. Default is "JSI".
        :type title: str
        :param plot_pump: Flag to plot the pump spectrum overlaid as contour plot. Default is False.
        :type plot_pump: bool
        :return: the axes handle for the plot
        """
        logger = logging.getLogger(__name__)
        if self.jsa is None:
            logger.info("The JSA was not calculated. I'll try to calculate it right away.")
            self.calculate_JSA()
        if ax is None:
            _, ax = plt.subplots(1, 1)
        x = np.asarray(self.phasematching.wavelength1) * 1e9
        y = np.asarray(self.phasematching.wavelength2) * 1e9
        if normalized:
            logger.debug("The user wants the normalised JSI.")
            # Divide into a new array: an in-place division would overwrite the stored JSI.
            jsi_to_plot = self.jsi / self.jsi.max()
        else:
            jsi_to_plot = self.jsi
        if light_plot:
            im = ax.imshow(jsi_to_plot, origin="lower", extent=[x.min(), x.max(), y.min(), y.max()], aspect="auto")
            warnings.warn("The light_plot mode is compatible only with linear meshes of the signal/idler wavelengths.")
        else:
            im = ax.pcolormesh(x, y, jsi_to_plot)
        if plot_pump:
            logger.debug("The user wants the pump contours on the JSI plot.")
            X, Y = np.meshgrid(x, y)
            Z = abs(self.pump.pump_spectrum) ** 2
            sigmas = np.arange(0, 4, 1)[::-1]
            levels = np.exp(-sigmas)
            warnings.warn("The sigmas indicated in the JSI plot refer to the intensity of the pump. "
                          "If you want the field information, you need to convert them accordingly.")
            labels_dict = {level: str(len(levels) - i - 1) + r"$\sigma$" for i, level in enumerate(levels)}
            CS = ax.contour(X, Y, Z / Z.max(), levels, colors="w", linestyles="-", linewidths=levels * 6)
            ax.clabel(CS, fmt=labels_dict, fontsize=9, inline=1)
        ax.figure.colorbar(im, ax=ax)
        ax.set_xlabel(r"$\lambda_{signal}$ [nm]")
        ax.set_ylabel(r"$\lambda_{idler}$ [nm]")
        ax.set_title(title)
        ax.figure.tight_layout()
        return ax

    def plot_marginals(self, ax=None, **kwargs):
        """
        Function to plot the marginals of the JSI.

        The JSA is calculated first if the marginals are not available.

        :param ax: Axes handles for the *two* axes where to draw the marginals. The input can be None or a sequence
            (list, tuple or numpy array) of the two axes handles. If None, a new plot is generated with two
            subplots, one for each marginal. Default: *None*
        :type ax: matplotlib.pyplot.axes
        :param kwargs: ``suptitle`` sets the figure title (default "Marginals").
        :return: the Axes handles
        :raises ValueError: if `ax` is not a sequence of exactly two axes.
        """
        if ax is None:
            _, ax = plt.subplots(1, 2)
        else:
            try:
                n_axes = len(ax)
            except TypeError:
                n_axes = None
            if n_axes != 2:
                raise ValueError(
                    "I need two different axes to plot the marginals. ax should be a list of axes handles [ax0, ax1]")
        if self.marginal1 is None or self.marginal2 is None:
            self.calculate_JSA()
        suptitle = kwargs.get("suptitle", "Marginals")
        ax[0].plot(np.asarray(self.phasematching.wavelength1) * 1e9, self.marginal1)
        ax[1].plot(np.asarray(self.phasematching.wavelength2) * 1e9, self.marginal2)
        ax[0].figure.suptitle(suptitle)
        return ax