|
|
|
|
|
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' |
|
|
|
|
|
import argparse |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import yaml |
|
|
import os |
|
|
import soundfile as sf |
|
|
from ml_collections import ConfigDict |
|
|
from omegaconf import OmegaConf |
|
|
from tqdm.auto import tqdm |
|
|
from typing import Dict, List, Tuple, Any, Union |
|
|
import loralib as lora |
|
|
import gc |
|
|
import logging |
|
|
|
|
|
|
|
|
# Module-level logging setup: records at INFO level and above are written to
# the file 'utils.log', each line formatted as "<timestamp> - <message>".
logging.basicConfig(level=logging.INFO, filename='utils.log', format='%(asctime)s - %(message)s')
|
|
|
|
|
def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
    """
    Load the configuration file used to build a model.

    Parameters:
    ----------
    model_type : str
        Model family name. 'htdemucs' configurations are loaded with OmegaConf;
        every other type is parsed as YAML and wrapped in a ConfigDict.
    config_path : str
        Path to the configuration file.

    Returns:
    -------
    config : ConfigDict or OmegaConf object
        The parsed configuration.

    Raises:
    ------
    FileNotFoundError:
        If `config_path` does not exist.
    ValueError:
        If the file exists but cannot be read or parsed.
    """
    try:
        # Explicit UTF-8 so that parsing does not depend on the platform's
        # default text encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            if model_type == 'htdemucs':
                config = OmegaConf.load(config_path)
            else:
                # NOTE(security): FullLoader can construct some Python objects.
                # Only load configuration files from trusted sources.
                config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
            return config
    except FileNotFoundError as e:
        # Chain the original error so the traceback keeps the OS-level details.
        raise FileNotFoundError(f"Configuration file not found at {config_path}") from e
    except Exception as e:
        raise ValueError(f"Error loading configuration: {e}") from e
|
|
|
|
|
def get_model_from_config(model_type: str, config_path: str) -> Tuple:
    """
    Build the model selected by `model_type` from its configuration file.

    Parameters:
    ----------
    model_type : str
        Name of the architecture to build (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.).
    config_path : str
        Path to the configuration file; it is read with `load_config`.

    Returns:
    -------
    model :
        The object constructed for `model_type`.
    config : Any
        The configuration object returned by `load_config` for this model type.

    Raises:
    ------
    ValueError:
        If `model_type` is not one of the supported names. Errors raised by
        `load_config` propagate unchanged.
    """
    # The configuration is loaded first, so a bad config path is reported even
    # when the model type is also unknown.
    config = load_config(model_type, config_path)

    # Each branch imports its architecture inside the branch, so only the
    # module of the selected model is imported.
    if model_type == 'mdx23c':
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
        return TFC_TDF_net(config), config

    if model_type == 'htdemucs':
        from models.demucs4ht import get_model
        return get_model(config), config

    if model_type == 'segm_models':
        from models.segm_models import Segm_Models_Net
        return Segm_Models_Net(config), config

    if model_type == 'torchseg':
        from models.torchseg_models import Torchseg_Net
        return Torchseg_Net(config), config

    if model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        return MelBandRoformer(**dict(config.model)), config

    if model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        return BSRoformer(**dict(config.model)), config

    if model_type == 'swin_upernet':
        from models.upernet_swin_transformers import Swin_UperNet_Model
        return Swin_UperNet_Model(config), config

    if model_type == 'bandit':
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
        return MultiMaskMultiSourceBandSplitRNNSimple(**config.model), config

    if model_type == 'bandit_v2':
        from models.bandit_v2.bandit import Bandit
        # Unlike the other keyword-built models, this one reads `config.kwargs`.
        return Bandit(**config.kwargs), config

    if model_type == 'scnet_unofficial':
        from models.scnet_unofficial import SCNet
        return SCNet(**config.model), config

    if model_type == 'scnet':
        from models.scnet import SCNet
        return SCNet(**config.model), config

    if model_type == 'apollo':
        from models.look2hear.models import BaseModel
        return BaseModel.apollo(**config.model), config

    if model_type == 'bs_mamba2':
        from models.ts_bs_mamba2 import Separator
        return Separator(**config.model), config

    if model_type == 'experimental_mdx23c_stht':
        from models.mdx23c_tfc_tdf_v3_with_STHT import TFC_TDF_net
        return TFC_TDF_net(config), config

    raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
|
def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
    """
    Read an audio file and return its samples transposed.

    Parameters:
    ----------
    path : str
        Path of the audio file, passed to `sf.read`.
    instr : str
        Stem name; only used in the message printed when a read error is skipped.
    skip_err : bool
        When True, a failed read prints a notice and yields `(None, None)`
        instead of raising.

    Returns:
    -------
    (samples, sample_rate)
        `samples` is the transpose of what `sf.read` returned; one-dimensional
        data first gets a trailing axis, so the result is always 2-D
        (presumably (channels, samples) - confirm against soundfile's layout).
        Both values are None when the error was skipped.

    Raises:
    ------
    RuntimeError:
        If reading fails and `skip_err` is False.
    """
    try:
        samples, sample_rate = sf.read(path)
        if samples.ndim == 1:
            # Give mono data an explicit second axis before transposing.
            samples = samples[:, np.newaxis]
        return samples.T, sample_rate
    except Exception as e:
        if not skip_err:
            raise RuntimeError(f"Error reading the file at {path}: {e}")
        print(f"No stem {instr}: skip!")
        return None, None
|
|
|
|
|
def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
    """
    Standardize audio using the statistics of its mean over the first axis.

    The array is averaged along axis 0; the mean and standard deviation of
    that average are then used to shift and scale the whole input.

    Returns:
    -------
    (normalized, params)
        `normalized` is `(audio - mean) / (std + 1e-8)` and `params` holds the
        statistics under the keys "mean" and "std".
    """
    downmix = audio.mean(0)
    stats = {"mean": downmix.mean(), "std": downmix.std()}
    # The small constant keeps the division finite for constant (std == 0) input.
    scaled = (audio - stats["mean"]) / (stats["std"] + 1e-8)
    return scaled, stats
|
|
|
|
|
def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
    """
    Rescale audio with stored statistics.

    Parameters:
    ----------
    audio : np.ndarray
        Audio to rescale.
    norm_params : Dict[str, float]
        Mapping with the keys "std" and "mean".

    Returns:
    -------
    np.ndarray
        `audio * std + mean`.
    """
    scale = norm_params["std"]
    offset = norm_params["mean"]
    return audio * scale + offset
|
|
|
|
|
def _flip_channels(audio):
    """Return a copy of `audio` with the order of its first axis reversed."""
    if isinstance(audio, torch.Tensor):
        # Tensors reject negative-step slices; torch.flip returns a new tensor.
        return torch.flip(audio, dims=[0])
    return np.asarray(audio)[::-1].copy()


def apply_tta(
    config,
    model: nn.Module,
    mix: torch.Tensor,
    waveforms_orig: Dict[str, torch.Tensor],
    device: str,
    model_type: str,
    progress=None
) -> Dict[str, torch.Tensor]:
    """
    Apply test-time augmentation and average the predictions.

    Two augmented copies of the mixture are separated with `demix`:
    one with the first axis reversed (presumably the channel order - confirm
    against the caller) and one with inverted polarity. Each prediction is
    mapped back (re-reversed or re-negated), added to `waveforms_orig`, and
    the sum is divided by the number of augmentations plus one.

    Parameters:
    ----------
    config :
        Configuration forwarded to `demix`.
    model : nn.Module
        Model forwarded to `demix`.
    mix :
        Mixture audio; a torch tensor or a numpy array is accepted.
    waveforms_orig : dict
        Per-stem predictions to accumulate into. Assumed to already hold the
        prediction for the un-augmented mixture, which is why the average
        divides by one more than the number of augmentations. The values are
        modified in place.
    device : str
        Device forwarded to `demix`; also decides whether the CUDA cache is emptied.
    model_type : str
        Model type forwarded to `demix`.
    progress : callable, optional
        Called as `progress(fraction, desc=...)`; this stage reports up to 0.5.

    Returns:
    -------
    dict
        The same `waveforms_orig` mapping, now holding the averaged predictions.
    """
    # `-mix` already produces a new array/tensor, so no explicit copy is needed.
    track_proc_list = [_flip_channels(mix), -mix]
    total_steps = len(track_proc_list)
    processed_steps = 0
    report = progress is not None and callable(progress)
    on_cuda = str(device).startswith('cuda')

    for i, augmented_mix in enumerate(track_proc_list):
        processed_steps += 1
        # This stage is reported as the first half (0-50%) of the progress range.
        progress_value = round((processed_steps / total_steps) * 50)
        if report:
            progress(progress_value / 100, desc=f"Applying TTA step {processed_steps}/{total_steps}")
        # NOTE(review): update_progress_html is not defined in the visible part
        # of this module; it is expected to exist at module level.
        update_progress_html(f"Applying TTA step {processed_steps}/{total_steps}", progress_value)

        waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False, progress=progress)
        for el in waveforms:
            if i == 0:
                # Undo the axis reversal before accumulating.
                waveforms_orig[el] += _flip_channels(waveforms[el])
            else:
                # Undo the polarity inversion: subtracting adds the re-negated prediction.
                waveforms_orig[el] -= waveforms[el]

        del waveforms, augmented_mix
        gc.collect()
        if on_cuda:
            torch.cuda.empty_cache()

    for el in waveforms_orig:
        waveforms_orig[el] /= (len(track_proc_list) + 1)

    if report:
        progress(0.5, desc="TTA completed")
    update_progress_html("TTA completed", 50)

    return waveforms_orig
|
|
|
|
|
def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor: |
|
|
fadein = torch.linspace(0, 1, fade_size) |
|
|
fadeout = torch.linspace(1, 0, fade_size) |
|
|
window = torch.ones(window_size) |
|
|
window[-fade_size:] = fadeout |
|
|
window[:fade_size] = fadein |
|
|
return window |
|
|
|
|
|
def demix(
    config: ConfigDict,
    model: nn.Module,
    mix: torch.Tensor,
    device: str,
    model_type: str,
    pbar: bool = False,
    progress=None
) -> Dict[str, np.ndarray]:
    """
    Separate a mixture into sources by running the model over overlapping chunks.

    The mixture is cut into chunks of `chunk_size` samples taken every `step`
    samples, the chunks are run through the model in batches, and the
    overlapping predictions are summed and divided by the accumulated weights.
    For every model type except 'htdemucs' ("generic" mode) each chunk is
    weighted by a fade-in/fade-out window and the mixture is reflect-padded
    at both ends when it is long enough; the padding is removed from the result.

    Parameters:
    ----------
    config : ConfigDict
        Reads `audio.chunk_size`, `inference.num_overlap`, optionally
        `inference.batch_size` (default 1) and, for 'htdemucs',
        `training.samplerate`, `training.segment` and `training.instruments`.
    model : nn.Module
        Model to run; it is moved to `device` and switched to eval mode.
    mix :
        Mixture audio, indexed as `mix[:, start:stop]` (presumably
        (channels, samples) - confirm against the caller). It is converted to
        a float16 CPU tensor.
    device : str
        Device the batches are sent to.
    model_type : str
        'htdemucs' selects the demucs chunking; anything else is "generic".
    pbar : bool
        Accepted for interface compatibility; not used by this implementation.
    progress : callable, optional
        Called as `progress(fraction, desc=...)` after every batch.

    Returns:
    -------
    Dict[str, np.ndarray]
        Maps each instrument name to its float32 estimate.
    """
    # Function-scope import: `time` is needed for the timing log below but is
    # not among the module-level imports.
    import time

    logging.info(f"Starting demix for model_type: {model_type}, chunk_size: {config.audio.chunk_size}")

    # The mixture and the accumulation buffers are kept in float16 on the CPU.
    mix = torch.tensor(mix, dtype=torch.float16, device='cpu')
    mode = 'demucs' if model_type == 'htdemucs' else 'generic'

    if mode == 'demucs':
        chunk_size = config.training.samplerate * config.training.segment
        num_instruments = len(config.training.instruments)
        num_overlap = config.inference.num_overlap
        step = chunk_size // num_overlap
    else:
        chunk_size = config.audio.chunk_size
        num_instruments = len(prefer_target_instrument(config))
        num_overlap = config.inference.num_overlap
        fade_size = chunk_size // 10
        step = chunk_size // num_overlap
        border = chunk_size - step
        length_init = mix.shape[-1]
        windowing_array = _getWindowingArray(chunk_size, fade_size).to('cpu', dtype=torch.float16)
        # Edge variants: the very first chunk must not fade in and the very
        # last chunk must not fade out, otherwise the track ends are attenuated.
        first_window = windowing_array.clone()
        first_window[:fade_size] = 1
        last_window = windowing_array.clone()
        last_window[-fade_size:] = 1
        if length_init > 2 * border and border > 0:
            # Reflect-pad both ends; the padding is cut off again after averaging.
            mix = nn.functional.pad(mix, (border, border), mode="reflect")

    batch_size = getattr(config.inference, 'batch_size', 1)
    on_cuda = str(device).startswith('cuda')
    report = progress is not None and callable(progress)

    model = model.to(device)
    model.eval()

    total_len = mix.shape[1]
    total_chunks = (total_len + step - 1) // step
    processed_chunks = 0

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=on_cuda, dtype=torch.float16):
            req_shape = (num_instruments,) + mix.shape
            # `result` sums the weighted predictions, `counter` sums the weights.
            result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
            counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')

            i = 0
            batch_data = []
            batch_locations = []
            start_time = time.time()

            while i < total_len:
                part = mix[:, i:i + chunk_size]
                chunk_len = part.shape[-1]
                # Reflect padding needs the pad to be shorter than the data, so
                # short tail chunks fall back to zero padding.
                pad_mode = "reflect" if mode == "generic" and chunk_len > chunk_size // 2 else "constant"
                part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)

                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                # Run the model once the batch is full or the input is exhausted.
                if len(batch_data) >= batch_size or i >= total_len:
                    arr = torch.stack(batch_data, dim=0).to(device, non_blocking=True)
                    x = model(arr)
                    x = x.cpu()

                    for j, (start, seg_len) in enumerate(batch_locations):
                        if mode == "generic":
                            # Pick the window per chunk rather than per batch,
                            # so only the true first/last chunk of the track
                            # gets an edge window whatever the batch size is.
                            if start == 0:
                                window = first_window
                            elif start + step >= total_len:
                                window = last_window
                            else:
                                window = windowing_array
                            result[..., start:start + seg_len] += (x[j, ..., :seg_len] * window[..., :seg_len])
                            counter[..., start:start + seg_len] += window[..., :seg_len]
                        else:
                            result[..., start:start + seg_len] += x[j, ..., :seg_len]
                            counter[..., start:start + seg_len] += 1.0

                    processed_chunks += len(batch_data)
                    progress_value = min(round((processed_chunks / total_chunks) * 100), 100)
                    if report:
                        progress(progress_value / 100, desc=f"Processing chunk {processed_chunks}/{total_chunks}")
                    # NOTE(review): update_progress_html is not defined in the
                    # visible part of this module; it is expected to exist at
                    # module level.
                    update_progress_html(f"Processing chunk {processed_chunks}/{total_chunks}", progress_value)

                    # Release the batch before building the next one.
                    del arr, x
                    batch_data.clear()
                    batch_locations.clear()
                    gc.collect()
                    if on_cuda:
                        torch.cuda.empty_cache()
                        logging.info("Cleared CUDA cache")

            elapsed_time = time.time() - start_time
            logging.info(f"Demix completed in {elapsed_time:.2f} seconds")

            # NOTE(review): 1e-8 is below float16 resolution, so positions with
            # zero accumulated weight most likely still yield NaN here; the
            # nan_to_num call below turns those into 0.
            estimated_sources = result / (counter + 1e-8)
            estimated_sources = estimated_sources.numpy().astype(np.float32)
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            if mode == "generic" and length_init > 2 * border and border > 0:
                # Drop the reflect padding added before chunking.
                estimated_sources = estimated_sources[..., border:-border]

    instruments = config.training.instruments if mode == "demucs" else prefer_target_instrument(config)
    ret_data = {k: v for k, v in zip(instruments, estimated_sources)}
    logging.info("Demix completed successfully")

    if report:
        progress(1.0, desc="Demix completed")
    update_progress_html("Demix completed", 100)

    return ret_data
|
|
|
|
|
def prefer_target_instrument(config: ConfigDict) -> List[str]:
    """
    Return the instruments the model is expected to output.

    If `config.training.target_instrument` exists and is truthy, the result is
    a one-element list containing it; otherwise `config.training.instruments`
    is returned as is.
    """
    training = config.training
    target = getattr(training, 'target_instrument', None)
    if target:
        return [target]
    return training.instruments
|
|
|
|
|
def load_not_compatible_weights(model: nn.Module, weights: str, verbose: bool = False) -> None:
    """
    Load the parts of a checkpoint that fit the model and keep the rest.

    Every entry of the model's state dict whose name exists in the checkpoint
    with an identical shape is replaced by the checkpoint tensor; all other
    entries keep the model's current values.

    Parameters:
    ----------
    model : nn.Module
        Model that receives the weights (modified in place).
    weights : str
        Path of the checkpoint file. The state dict may be stored directly or
        nested under a 'state' and/or 'state_dict' key.
    verbose : bool
        When True, print every entry that could not be taken from the checkpoint.
    """
    new_model = model.state_dict()
    # NOTE(security): depending on the installed torch version's defaults,
    # torch.load may unpickle arbitrary objects - only load trusted checkpoints.
    old_model = torch.load(weights, map_location='cpu')
    if 'state' in old_model:
        old_model = old_model['state']
    if 'state_dict' in old_model:
        old_model = old_model['state_dict']

    for el in new_model:
        if el not in old_model:
            if verbose:
                print(f"Match not found for {el}!")
            continue
        if new_model[el].shape == old_model[el].shape:
            new_model[el] = old_model[el]
        elif verbose:
            print(
                f"Shape mismatch for {el}: model {tuple(new_model[el].shape)} "
                f"vs checkpoint {tuple(old_model[el].shape)}, keeping model weights"
            )

    model.load_state_dict(new_model)
|
|
|
|
|
def load_lora_weights(model: nn.Module, lora_path: str, device: str = 'cpu') -> None:
    """
    Load a state dict from `lora_path` into `model` without strict key matching.

    Parameters:
    ----------
    model : nn.Module
        Model that receives the weights (modified in place).
    lora_path : str
        Path of the file holding the state dict.
    device : str
        Passed to `torch.load` as `map_location`.
    """
    # strict=False: keys missing from either side are tolerated.
    model.load_state_dict(torch.load(lora_path, map_location=device), strict=False)
|
|
|
|
|
def load_start_checkpoint(args: argparse.Namespace, model: nn.Module, type_='train') -> None:
    """
    Initialize `model` from the checkpoint named in `args`.

    Parameters:
    ----------
    args : argparse.Namespace
        Reads `start_check_point` (checkpoint path), `model_type` and
        `lora_checkpoint` (optional path of extra LoRA weights).
    model : nn.Module
        Model that receives the weights (modified in place).
    type_ : str
        Accepted for interface compatibility; not used by this implementation.
    """
    checkpoint_path = args.start_check_point
    print(f'Start from checkpoint: {checkpoint_path}')

    device = 'cpu'
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # For these model types the weights may be nested under 'state' or
    # 'state_dict'; 'state' takes precedence when both are present.
    if args.model_type in ('htdemucs', 'apollo') and isinstance(state_dict, dict):
        if 'state' in state_dict:
            state_dict = state_dict['state']
        elif 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']

    model.load_state_dict(state_dict)

    lora_path = args.lora_checkpoint
    if not lora_path:
        return
    print(f"Loading LoRA weights from: {lora_path}")
    load_lora_weights(model, lora_path, device)
|
|
|
|
|
def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
    """
    Replace every `nn.Linear` submodule of `model` with `lora.MergedLinear`.

    Parameters:
    ----------
    config : Dict[str, Any]
        Must contain the key 'lora'; its mapping is passed as keyword
        arguments to every `lora.MergedLinear`.
    model : nn.Module
        Model whose linear layers are replaced in place.

    Returns:
    -------
    nn.Module
        The same `model` object.

    Raises:
    ------
    ValueError:
        If `config` has no 'lora' key.

    A layer that cannot be replaced is reported with a printed message and
    left unchanged; the number of replaced layers is printed at the end.
    """
    if 'lora' not in config:
        raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")

    replaced_layers = 0
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        # "a.b.c" -> parent path ["a", "b"] and attribute name "c".
        *parent_path, layer_name = name.split('.')
        try:
            # Walk down from the root to the module that owns this layer.
            parent_module = model
            for submodule_name in parent_path:
                parent_module = getattr(parent_module, submodule_name)

            lora_layer = lora.MergedLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                **config['lora']
            )
            setattr(parent_module, layer_name, lora_layer)
            replaced_layers += 1
        except Exception as e:
            # Best effort: report the failure and continue with the next layer.
            print(f"Error replacing layer {name}: {e}")

    print(f"Number of layers replaced with LoRA: {replaced_layers}")
    return model
|
|
|
|
|
def draw_spectrogram(waveform, sample_rate, length, output_file):
    """
    Render a dB-scaled spectrogram of the start of a waveform and save it.

    Parameters:
    ----------
    waveform :
        2-D array indexed as `waveform[sample, ...]`; it is averaged over its
        last axis (presumably the channels - confirm against the caller)
        before the STFT.
    sample_rate :
        Sampling rate, used for the time axis and to convert `length` to a
        number of samples.
    length :
        Only the first `length * sample_rate` samples are plotted
        (presumably `length` is in seconds).
    output_file :
        Path of the image to write. When falsy, the figure is built but
        neither titled nor saved.
    """
    import librosa.display
    # `plt` is needed below but is not imported at module level; the import is
    # kept local, next to the librosa import that is only needed for plotting.
    import matplotlib.pyplot as plt

    x = waveform[:int(length * sample_rate), :]
    X = librosa.stft(x.mean(axis=-1))
    # Magnitudes in dB relative to the loudest bin.
    Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)

    fig, ax = plt.subplots()
    try:
        img = librosa.display.specshow(
            Xdb, cmap='plasma', sr=sample_rate, x_axis='time', y_axis='linear', ax=ax
        )
        if output_file:
            # The title is derived from the path, so it is only set when one
            # is given (os.path.basename(None) would raise).
            ax.set(title='File: ' + os.path.basename(output_file))
        fig.colorbar(img, ax=ax, format="%+2.f dB")
        if output_file:
            fig.savefig(output_file)
    finally:
        # Close this figure even when plotting or saving fails, so figures do
        # not accumulate across calls.
        plt.close(fig)