from typing import List, Optional, Tuple

import numpy as np

from ctranslate2.specs import common_spec, model_spec, transformer_spec


class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2 model."""

    def __init__(self):
        # NOTE(review): intentionally does NOT call super().__init__() —
        # presumably to skip the parent ModelConfig's default fields, which
        # do not apply to Wav2Vec2. Confirm before changing.
        return


class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """Top-level specification of a Wav2Vec2 model (encoder + LM head)."""

    def __init__(self, feat_layers, num_layers, num_heads):
        """Initializes the Wav2Vec2 specification.

        Args:
          feat_layers: Number of feature-extraction convolution layers.
          num_layers: Number of Transformer encoder layers.
          num_heads: Number of attention heads in each encoder layer.
        """
        super().__init__()
        self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)

    @property
    def name(self):
        """Unique name identifying this specification."""
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        """Revision number of this specification."""
        return 3

    def get_default_config(self):
        """Returns the default model configuration."""
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        """Returns the output vocabulary size, read from the LM head weight."""
        return self.encoder.lm_head.weight.shape[0]


class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
    """A 1D convolution followed by layer normalization.

    Used as a feature-extraction layer in the Wav2Vec2 encoder.
    """

    def __init__(self):
        self.conv = common_spec.Conv1DSpec()
        self.layer_norm = common_spec.LayerNormSpec()


class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
    """Convolutional positional embedding layer."""

    def __init__(self):
        self.conv = common_spec.Conv1DSpec()


class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Specification of the Wav2Vec2 encoder.

    Holds the convolutional feature extractor, the feature projection,
    the positional convolution embedding, the stack of Transformer
    encoder layers, and the final LM head.
    """

    def __init__(self, feat_layers, num_layers, num_heads):
        """Builds the encoder specification.

        Args:
          feat_layers: Total number of feature-extraction conv layers
            (the first is stored separately in ``feat_layer0``).
          num_layers: Number of Transformer encoder layers.
          num_heads: Number of attention heads per encoder layer.
        """
        # Stored as int16 so it is serialized with a fixed-width dtype.
        self.num_heads = np.dtype("int16").type(num_heads)
        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
        # Remaining feat_layers - 1 feature-extraction layers.
        self.feat_layer = [
            Wav2Vec2LayerNormConvLayer() for _ in range(feat_layers - 1)
        ]
        self.fp_layer_norm = common_spec.LayerNormSpec()
        self.fp_projection = common_spec.LinearSpec()
        self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
        self.layer_norm = common_spec.LayerNormSpec()
        self.layer = [
            transformer_spec.TransformerEncoderLayerSpec()
            for _ in range(num_layers)
        ]
        self.lm_head = common_spec.LinearSpec()