Source code for flexcode.basis_functions

"""Functions for evaluation of orthogonal basis functions."""

import numpy as np
import pywt

from .helpers import box_transform, make_grid
from .post_processing import *


[docs]def evaluate_basis(responses, n_basis, basis_system): """Evaluates a system of basis functions. Arguments ---------- responses : array An array of responses in [0, 1]. n_basis : integer The number of basis functions to calculate. basis_system : {'cosine', 'Fourier', 'db4'} String denoting the system of orthogonal basis functions. Returns ------- numpy matrix A matrix of basis functions evaluations. Each row corresponds to a value of `responses`, each column corresponds to a basis function. Raises ------ ValueError If the basis system isn't recognized. """ systems = {"cosine": cosine_basis, "Fourier": fourier_basis, "db4": wavelet_basis} try: basis_fn = systems[basis_system] except KeyError: raise ValueError("Basis system {} not recognized".format(basis_system)) n_dim = responses.shape[1] if n_dim == 1: return basis_fn(responses, n_basis) else: if len(n_basis) == 1: n_basis = [n_basis] * n_dim return tensor_basis(responses, n_basis, basis_fn)
[docs]def tensor_basis(responses, n_basis, basis_fn): """Evaluates tensor basis. Combines single-dimensional basis functions \phi_{d}(z) to form orthogonal tensor basis $\phi(z_{1}, \dots, z_{D}) = \prod_{d} \phi_{d}(z_{d})$. Arguments --------- responses : numpy matrix A matrix of responses in [0, 1]^(n_dim). Each column corresponds to a variable, each row corresponds to an observation. n_basis : list of integers The number of basis function for each dimension. Should have the same length as the number of columns of `responses`. basis_fn : function The function which evaluates the one-dimensional basis functions. Returns ------- numpy matrix Returns a matrix where each column is a basis function and each row is an observation. """ n_obs, n_dims = responses.shape basis = np.ones((n_obs, np.prod(n_basis))) period = 1 for dim in range(n_dims): sub_basis = basis_fn(responses[:, dim], n_basis[dim]) col = 0 for _ in range(np.prod(n_basis) // (n_basis[dim] * period)): for sub_col in range(n_basis[dim]): for _ in range(period): basis[:, col] *= sub_basis[:, sub_col] col += 1 period *= n_basis[dim] return basis
[docs]def cosine_basis(responses, n_basis): """Evaluates cosine basis. Arguments ---------- responses : array An array of responses in [0, 1]. n_basis : integer The number of basis functions to evaluate. Returns ------- numpy matrix A matrix of cosine basis functions evaluated at `responses`. Each row corresponds to a value of `responses`, each column corresponds to a basis function. """ n_obs = responses.shape[0] basis = np.empty((n_obs, n_basis)) responses = responses.flatten() basis[:, 0] = 1.0 for col in range(1, n_basis): basis[:, col] = np.sqrt(2) * np.cos(np.pi * col * responses) return basis
[docs]def fourier_basis(responses, n_basis): """Evaluates Fourier basis. Arguments ---------- responses : array An array of responses in [0, 1]. n_basis : integer The number of basis functions to evaluate. Returns ------- numpy matrix A matrix of Fourier basis functions evaluated at `responses`. Each row corresponds to a value of `responses`, each column corresponds to a basis function. """ n_obs = responses.shape[0] basis = np.zeros((n_obs, n_basis)) responses = responses.flatten() basis[:, 0] = 1.0 for col in range(1, (n_basis + 1) // 2): basis[:, 2 * col - 1] = np.sqrt(2) * np.sin(2 * np.pi * col * responses) basis[:, 2 * col] = np.sqrt(2) * np.cos(2 * np.pi * col * responses) if n_basis % 2 == 0: basis[:, -1] = np.sqrt(2) * np.sin(np.pi * n_basis * responses) return basis
[docs]def wavelet_basis(responses, n_basis, family="db4"): """Evaluates Daubechies basis. Arguments ---------- responses : array An array of responses in [0, 1]. n_basis : integer The number of basis functions to evaluate. family : string The wavelet family to evaluate. Returns ------- numpy matrix A matrix of Fourier basis functions evaluated at `responses`. Each row corresponds to a value of `responses`, each column corresponds to a basis function. """ responses = responses.flatten() n_aux = 15 rez = pywt.DiscreteContinuousWavelet(family).wavefun(n_aux) if len(rez) == 2: wavelet, x_grid = rez else: _, wavelet, x_grid = rez wavelet *= np.sqrt(max(x_grid) - min(x_grid)) x_grid = (x_grid - min(x_grid)) / (max(x_grid) - min(x_grid)) def _wave_fun(val): if val < 0 or val > 1: return 0.0 return wavelet[np.argmin(abs(val - x_grid))] n_obs = responses.shape[0] basis = np.empty((n_obs, n_basis)) basis[:, 0] = 1.0 loc = 0 level = 0 for col in range(1, n_basis): basis[:, col] = [2 ** (level / 2) * _wave_fun(a * 2**level - loc) for a in responses] loc += 1 if loc == 2**level: loc = 0 level += 1 return basis
[docs]class BasisCoefs(object): def __init__(self, coefs, basis_system, z_min, z_max, bump_threshold=None, sharpen_alpha=None): self.coefs = coefs self.basis_system = basis_system self.z_min = z_min self.z_max = z_max self.bump_threshold = bump_threshold self.sharpen_alpha = sharpen_alpha
[docs] def evaluate(self, z_grid): basis = evaluate_basis( box_transform(z_grid, self.z_min, self.z_max), self.coefs.shape[1], self.basis_system ) cdes = np.matmul(self.coefs, basis.T) normalize(cdes) if self.bump_threshold is not None: remove_bumps(cdes, self.bump_threshold) if self.sharpen_alpha is not None: sharpen(cdes, self.sharpen_alpha) cdes /= self.z_max - self.z_min return cdes