Skip to content

Commit

Permalink
chore(format): run black on dev (#573)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] authored Jul 16, 2024
1 parent 27331c3 commit 53476da
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 41 deletions.
43 changes: 24 additions & 19 deletions ChatTTS/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


@dataclass(repr=False, eq=False)
class Path():
class Path:
vocos_ckpt_path: str = "asset/Vocos.pt"
dvae_ckpt_path: str = "asset/DVAE.pt"
gpt_ckpt_path: str = "asset/GPT.pt"
Expand All @@ -11,7 +11,7 @@ class Path():


@dataclass(repr=False, eq=False)
class Decoder():
class Decoder:
idim: int = 384
odim: int = 384
hidden: int = 512
Expand All @@ -20,26 +20,27 @@ class Decoder():


@dataclass(repr=False, eq=False)
class VQ():
class VQ:
dim: int = 1024
levels: tuple = (5,5,5,5)
levels: tuple = (5, 5, 5, 5)
G: int = 2
R: int = 2


@dataclass(repr=False, eq=False)
class DVAE():
class DVAE:
decoder: Decoder = Decoder(
idim=512,
odim=512,
hidden=256,
n_layer=12,
bn_dim=128,
idim=512,
odim=512,
hidden=256,
n_layer=12,
bn_dim=128,
)
vq: VQ = VQ()


@dataclass(repr=False, eq=False)
class GPT():
class GPT:
hidden_size: int = 768
intermediate_size: int = 3072
num_attention_heads: int = 12
Expand All @@ -54,53 +55,57 @@ class GPT():


@dataclass(repr=False, eq=False)
class FeatureExtractorInitArgs():
class FeatureExtractorInitArgs:
sample_rate: int = 24000
n_fft: int = 1024
hop_length: int = 256
n_mels: int = 100
padding: str = "center"


@dataclass(repr=False, eq=False)
class FeatureExtractor():
class FeatureExtractor:
class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs()


@dataclass(repr=False, eq=False)
class BackboneInitArgs():
class BackboneInitArgs:
input_channels: int = 100
dim: int = 512
intermediate_dim: int = 1536
num_layers: int = 8


@dataclass(repr=False, eq=False)
class Backbone():
class Backbone:
class_path: str = "vocos.models.VocosBackbone"
init_args: BackboneInitArgs = BackboneInitArgs()


@dataclass(repr=False, eq=False)
class FourierHeadInitArgs():
class FourierHeadInitArgs:
dim: int = 512
n_fft: int = 1024
hop_length: int = 256
padding: str = "center"


@dataclass(repr=False, eq=False)
class FourierHead():
class FourierHead:
class_path: str = "vocos.heads.ISTFTHead"
init_args: FourierHeadInitArgs = FourierHeadInitArgs()


@dataclass(repr=False, eq=False)
class Vocos():
class Vocos:
feature_extractor: FeatureExtractor = FeatureExtractor()
backbone: Backbone = Backbone()
head: FourierHead = FourierHead()


@dataclass(repr=False, eq=False)
class Config():
class Config:
path: Path = Path()
decoder: Decoder = Decoder()
dvae: DVAE = DVAE()
Expand Down
46 changes: 25 additions & 21 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def _load(
self.device = device
self.compile = compile

feature_extractor = instantiate_class(args=(), init=asdict(self.config.vocos.feature_extractor))
feature_extractor = instantiate_class(
args=(), init=asdict(self.config.vocos.feature_extractor)
)
backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone))
head = instantiate_class(args=(), init=asdict(self.config.vocos.head))
vocos = (
Expand All @@ -272,23 +274,23 @@ def _load(
.eval()
)
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
vocos.load_state_dict(
torch.load(vocos_ckpt_path, weights_only=True, mmap=True)
)
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
self.vocos = vocos
self.logger.log(logging.INFO, "vocos loaded.")

dvae = DVAE(
decoder_config=asdict(self.config.dvae.decoder),
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
).to(device).eval()
dvae = (
DVAE(
decoder_config=asdict(self.config.dvae.decoder),
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
)
.to(device)
.eval()
)
coef = str(dvae)
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
dvae.load_state_dict(
torch.load(dvae_ckpt_path, weights_only=True, mmap=True)
)
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
self.dvae = dvae
self.logger.log(logging.INFO, "dvae loaded.")

Expand All @@ -303,9 +305,7 @@ def _load(
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
assert os.path.exists(
spk_stat_path
), f"Missing spk_stat.pt: {spk_stat_path}"
assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
spk_stat: torch.Tensor = torch.load(
spk_stat_path,
weights_only=True,
Expand All @@ -315,11 +315,15 @@ def _load(
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
self.logger.log(logging.INFO, "gpt loaded.")

decoder = DVAE(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
).to(device).eval()
decoder = (
DVAE(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
)
.to(device)
.eval()
)
coef = str(decoder)
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
decoder.load_state_dict(
Expand Down
2 changes: 1 addition & 1 deletion ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class DVAE(nn.Module):
def __init__(
self,
decoder_config: dict,
vq_config: Optional[dict]=None,
vq_config: Optional[dict] = None,
dim=512,
coef: Optional[str] = None,
):
Expand Down

0 comments on commit 53476da

Please sign in to comment.