Source code for wwinpy._utils

"""Utility functions for WWINP file processing.

Provides helper functions for data verification and grid operations.
"""

# utils.py

from typing import List, Optional, Tuple
import numpy as np

def verify_and_correct(ni: int, nt: Optional[List[int]], ne: List[int], iv: int, verbose: bool = False) -> Tuple[int, Optional[List[int]], List[int]]:
    """Verify and correct weight window input parameters.

    The particle count ``ni`` and the per-particle group lists are made
    mutually consistent: lists are truncated to a common length, ``ni`` is
    reset to that length, and every particle type with zero energy groups
    (or, in the time-dependent case, zero time groups) is removed.

    The input lists are never modified; corrected copies are returned.

    :param ni: Number of particle types
    :type ni: int
    :param nt: Time groups per particle type (None if iv != 2)
    :type nt: Optional[List[int]]
    :param ne: Energy groups per particle type
    :type ne: List[int]
    :param iv: Time dependency flag (2 for time-dependent)
    :type iv: int
    :param verbose: Print a warning for every correction that is applied
    :type verbose: bool
    :return: Tuple of (corrected ni, corrected nt, corrected ne). ``nt`` is
        returned untouched when the time-dependent path is not taken.
    :rtype: Tuple[int, Optional[List[int]], List[int]]
    """
    # nt is only consulted when the header is flagged time-dependent AND a
    # time-group list was actually supplied.
    time_dependent = iv == 2 and nt is not None

    # Work on copies: the deletions in steps 3 and 4 would otherwise modify
    # the caller's lists in place whenever no truncation happened first.
    ne = list(ne)
    if time_dependent:
        nt = list(nt)

    changes_made = False

    if time_dependent:
        # Step 1: nt and ne must describe the same number of particle types.
        if len(nt) != len(ne):
            min_length = min(len(nt), len(ne))
            if verbose:
                print(
                    f"Warning: Length of nt ({len(nt)}) and ne ({len(ne)}) do not match. Truncating to {min_length}."
                )
            nt = nt[:min_length]
            ne = ne[:min_length]
            ni = min_length
            changes_made = True

        # Step 2: both lists must match ni.
        if len(ne) != ni or len(nt) != ni:
            if verbose:
                print(
                    f"Warning: Length of ne ({len(ne)}) or nt ({len(nt)}) does not match ni ({ni}). Adjusting ni to {min(len(ne), len(nt))}."
                )
            ni = min(len(ne), len(nt))
            ne = ne[:ni]
            nt = nt[:ni]
            changes_made = True
    elif len(ne) != ni:
        # Step 2 (no time groups): ne alone defines the particle count.
        if verbose:
            print(
                f"Warning: Length of ne ({len(ne)}) does not match ni ({ni}). Adjusting ni to {len(ne)}."
            )
        ni = len(ne)
        changes_made = True

    # Step 3: drop particle types with zero energy groups. Iterate from the
    # highest index down so deletions do not shift indices still to visit.
    zero_ne_indices = [i for i, val in enumerate(ne) if val == 0]
    if zero_ne_indices:
        changes_made = True
        for i in reversed(zero_ne_indices):
            if verbose:
                if time_dependent and nt[i] == 0:
                    print(
                        f"Warning: Particle type {i} has 0 energy and 0 time groups. It has been deleted and ni updated."
                    )
                else:
                    print(
                        f"Warning: Particle type {i} has 0 energy groups. It has been deleted and ni updated."
                    )
            del ne[i]
            if time_dependent:
                del nt[i]
            ni -= 1

    # Step 4: in the time-dependent case, also drop particle types with
    # zero time groups (indices refer to the lists as left by step 3).
    if time_dependent:
        zero_nt_indices = [i for i, val in enumerate(nt) if val == 0]
        if zero_nt_indices:
            changes_made = True
            for i in reversed(zero_nt_indices):
                if verbose:
                    print(
                        f"Warning: Particle type {i} has 0 time groups. It has been deleted and ni updated."
                    )
                del nt[i]
                del ne[i]
                ni -= 1

    # After step 2 the list lengths equal ni, and steps 3/4 remove one entry
    # per decrement of ni, so no further length reconciliation is needed.
    if not changes_made and verbose:
        print("Header verification complete. No changes made.")
    return ni, nt, ne
def get_closest_energy_indices(energy_grid: np.ndarray, energy_value: float, atol: float = 1e-9) -> np.ndarray:
    """Find energy indices bounding a value in an energy grid.

    Index ``i`` is the first grid point that is not below the value
    (``np.searchsorted`` with ``side='left'``). When the value coincides with
    a grid point ``i`` within tolerance, both ``i`` and ``i + 1`` are
    returned, unless ``i`` is the last grid point.

    :param energy_grid: Sorted energy grid points array
    :type energy_grid: np.ndarray
    :param energy_value: Target energy value
    :type energy_value: float
    :param atol: Absolute tolerance for float comparison (passed to
        ``np.isclose`` together with its default relative tolerance)
    :type atol: float, optional
    :return: Array of one or two indices. Two indices only when the value
        matches a grid point that is not the last one
    :rtype: np.ndarray
    :raises ValueError: If the energy grid is empty
    """
    n_points = len(energy_grid)
    if n_points == 0:
        raise ValueError("Energy grid cannot be empty")

    last = n_points - 1

    # Anything above the grid is assigned to the last index.
    if energy_value > energy_grid[-1]:
        return np.array([last])

    idx = np.searchsorted(energy_grid, energy_value, side='left')

    # A value that matches a grid point within tolerance can sit on either
    # side of it: just below puts the point at idx, just above (e.g.
    # 0.1 + 0.2 versus 0.3) puts it at idx - 1. Check both neighbours so the
    # tolerance is symmetric.
    if idx < n_points and np.isclose(energy_grid[idx], energy_value, atol=atol):
        match = idx
    elif idx > 0 and np.isclose(energy_grid[idx - 1], energy_value, atol=atol):
        match = idx - 1
    else:
        # No grid point matched: the insertion index is the answer.
        return np.array([idx])

    if match == last:
        # No following index exists for the last grid point.
        return np.array([match])
    return np.array([match, match + 1])
def get_range_energy_indices(grid: np.ndarray, range_tuple: Tuple[float, float], atol: float = 1e-9) -> np.ndarray:
    """Find energy indices for a range of values in an energy grid.

    Each bound is located with ``np.searchsorted`` (first grid point not
    below the bound). A bound that matches a grid point within tolerance is
    treated as lying exactly on that point. The returned indices run from the
    index of ``v_min`` to the index of ``v_max`` inclusive, and never exceed
    the last valid grid index.

    :param grid: Sorted energy grid points array
    :type grid: np.ndarray
    :param range_tuple: (min, max) energy range values
    :type range_tuple: Tuple[float, float]
    :param atol: Absolute tolerance for float comparison (passed to
        ``np.isclose`` together with its default relative tolerance)
    :type atol: float, optional
    :return: Array of indices covering the range; empty when the lower bound
        lies above the whole grid
    :rtype: np.ndarray
    :raises ValueError: If grid is empty or range is invalid
    """
    if not grid.size:
        raise ValueError("Energy grid cannot be empty")

    v_min, v_max = range_tuple
    if v_min > v_max:
        raise ValueError(f"Invalid range: min {v_min} is greater than max {v_max}")

    n_points = len(grid)

    def _locate(value):
        # Insertion index of value, snapped back onto a grid point that the
        # value matches within tolerance but slightly exceeds.
        idx = np.searchsorted(grid, value)
        if idx < n_points and np.isclose(grid[idx], value, atol=atol):
            return idx
        if idx > 0 and np.isclose(grid[idx - 1], value, atol=atol):
            return idx - 1
        return idx

    start_idx = _locate(v_min)
    if start_idx == n_points:
        # Lower bound is above every grid point: nothing to select.
        return np.array([], dtype=int)

    # Clamp so an upper bound above the grid cannot yield index len(grid).
    end_idx = min(_locate(v_max), n_points - 1)

    if start_idx > end_idx:
        return np.array([], dtype=int)
    return np.arange(start_idx, end_idx + 1)
def get_energy_intervals_from_indices(bins: np.ndarray, indices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get energy interval boundaries for given indices.

    Index ``i`` maps to the interval ``(bins[i - 1], bins[i])``; index 0 has
    no predecessor and uses ``0.0`` as its lower bound.

    :param bins: Energy grid boundaries array
    :type bins: np.ndarray
    :param indices: Array of indices from get_closest_energy_indices
    :type indices: np.ndarray
    :return: Tuple of (starts, ends) arrays for energy intervals
    :rtype: Tuple[np.ndarray, np.ndarray]
    """
    if len(indices) == 0:
        return np.array([]), np.array([])

    # Lower edge: previous grid value, or 0.0 for the very first bin.
    lower_edges = [0.0 if i == 0 else bins[i - 1] for i in indices]
    # Upper edge: the grid value at the index itself.
    upper_edges = [bins[i] for i in indices]
    return np.array(lower_edges), np.array(upper_edges)
def get_closest_indices(grid: np.ndarray, value: float, atol: float = 1e-9) -> np.ndarray:
    """Find indices bounding a value in a grid.

    For a value strictly between two grid points the two neighbouring
    indices are returned. For a value that matches an interior grid point
    ``i`` within tolerance, ``[i - 1, i + 1]`` is returned so both adjacent
    intervals are covered. Values outside the grid print a warning and are
    assigned to the first or last pair of grid points.

    :param grid: Sorted grid points array
    :type grid: np.ndarray
    :param value: Target value
    :type value: float
    :param atol: Absolute tolerance for float comparison (passed to
        ``np.isclose`` together with its default relative tolerance)
    :type atol: float, optional
    :return: Array of two bounding indices, or the single index ``[0]`` when
        the grid has exactly two points
    :rtype: np.ndarray
    """
    n_points = len(grid)
    first_pair = np.array([0, 1])
    last_pair = np.array([n_points - 2, n_points - 1])

    # A two-point grid has a single interval; out-of-range values only warn.
    if n_points == 2:
        if value < grid[0]:
            print(f"Warning: Value {value:.4e} is below the grid range. Using first grid point {grid[0]:.4e}.")
        if value > grid[-1]:
            print(f"Warning: Value {value:.4e} is above the grid range. Using last grid point {grid[-1]:.4e}.")
        return np.array([0])

    if value < grid[0]:
        print(f"Warning: Value {value:.4e} is below the grid range. Using first grid point {grid[0]:.4e}.")
        return first_pair
    if value > grid[-1]:
        print(f"Warning: Value {value:.4e} is above the grid range. Using last grid point {grid[-1]:.4e}.")
        return last_pair

    idx = np.searchsorted(grid, value)
    if idx == 0:
        return first_pair
    if idx == n_points:
        # Only reachable for values that compare neither below nor above
        # the grid (e.g. NaN); kept as a guard against indexing past the end.
        return last_pair

    # A value matching a grid point within tolerance can fall on either side
    # of it: just below leaves the point at idx, just above (e.g. 0.1 + 0.2
    # versus 0.3) leaves it at idx - 1. Check both so the tolerance is
    # symmetric. idx >= 1 here, so idx - 1 is a valid index.
    if np.isclose(grid[idx], value, atol=atol):
        match = idx
    elif np.isclose(grid[idx - 1], value, atol=atol):
        match = idx - 1
    else:
        return np.array([idx - 1, idx])

    if match == 0:
        return first_pair
    if match == n_points - 1:
        return last_pair
    return np.array([match - 1, match + 1])
def get_range_indices(grid: np.ndarray, range_tuple: Tuple[float, float]) -> np.ndarray:
    """Find grid indices within a range.

    The lower index is the grid point equal to ``v_min`` (within tolerance)
    or else the closest grid point below it; the upper index is the grid
    point equal to ``v_max`` or else the closest grid point above it. Bounds
    outside the grid print a warning and are clamped to the grid ends.

    :param grid: Sorted grid points array
    :type grid: np.ndarray
    :param range_tuple: (min, max) range values
    :type range_tuple: Tuple[float, float]
    :return: Array of consecutive indices from the lower to the upper index;
        for a grid with fewer than two points the sentinel ``[0, inf]``
    :rtype: np.ndarray
    :raises ValueError: If range_tuple[0] > range_tuple[1]
    """
    v_min, v_max = range_tuple

    # Degenerate grids have no intervals; signal that with a sentinel pair.
    if grid.size <= 1:
        return np.array([0, np.inf])

    if v_min > v_max:
        raise ValueError(f"Invalid range: min {v_min} is greater than max {v_max}.")

    final_position = grid.size - 1

    # --- lower bound ---
    min_hits = np.isclose(grid, v_min, atol=1e-9)
    if min_hits.any():
        # argmax on a boolean mask yields the first matching grid point.
        first_idx = np.argmax(min_hits)
    elif v_min < grid[0]:
        print(f"Warning: Lower bound {v_min} is below the grid range. Using first grid point {grid[0]:.4e} as the lower limit.")
        first_idx = 0
    else:
        # Closest grid point at or below v_min, never negative.
        first_idx = max(np.searchsorted(grid, v_min, side='right') - 1, 0)

    # --- upper bound ---
    max_hits = np.isclose(grid, v_max, atol=1e-9)
    if max_hits.any():
        last_idx = np.argmax(max_hits)
    elif v_max > grid[-1]:
        print(f"Warning: Upper bound {v_max} is above the grid range. Using last grid point {grid[-1]:.4e} as the upper limit.")
        last_idx = final_position
    else:
        # Closest grid point at or above v_max, kept inside the grid.
        last_idx = min(np.searchsorted(grid, v_max, side='left'), final_position)

    # The upper index may never fall below the lower one.
    if last_idx < first_idx:
        last_idx = first_idx

    return np.arange(first_idx, last_idx + 1)
def get_bin_intervals_from_indices(bins: np.ndarray, indices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get bin interval boundaries for given indices.

    Only the first and last entries of ``indices`` are used: they select the
    consecutive intervals ``(bins[k], bins[k + 1])`` for ``k`` from the first
    index up to, but not including, the last index.

    Two special forms of ``indices`` are recognised before that:

    * ``[0]`` returns the single interval ``(bins[0], bins[1])``.
    * ``[0, inf]`` or ``[0, 0]`` returns the scalar pair ``(0.0, inf)``.

    :param bins: Array of bin boundaries
    :type bins: np.ndarray
    :param indices: Array of [start, end] indices
    :type indices: np.ndarray
    :return: Tuple of (bin_starts, bin_ends) arrays of equal length; the
        scalar pair ``(0.0, inf)`` for the sentinel case
    :rtype: Tuple[np.ndarray, np.ndarray]
    :raises ValueError: If bins is None or the indices do not define a valid
        range within the bins
    """
    if bins is None:
        raise ValueError("Bins must have at least one value.")

    # Single index 0: exactly one interval, the first one.
    if len(indices) == 1 and indices[0] == 0:
        return np.array([bins[0]]), np.array([bins[1]])

    # Sentinel pair: one unbounded interval. This must be tested before
    # bins is indexed, because it is accepted even when bins is empty.
    if indices[0] == 0 and (indices[1] == np.inf or indices[1] == 0):
        return 0.0, float('inf')

    start_idx = indices[0]
    end_idx = indices[-1]

    # end_idx is itself used as an index (bins[end_idx] is the last upper
    # edge), so it must be strictly below len(bins). Allowing equality would
    # let slicing silently drop the final upper edge and return arrays of
    # different lengths.
    if start_idx > end_idx or start_idx < 0 or end_idx >= len(bins):
        raise ValueError("Indices must define a valid range within the bins.")

    bin_starts = bins[start_idx:end_idx]
    bin_ends = bins[start_idx + 1:end_idx + 1]
    return bin_starts, bin_ends