"""Conversion of Fairseq Transformer checkpoints into CTranslate2 model specs."""

import argparse
import os

from typing import Optional

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

# Fairseq model names (resolved from the checkpoint architecture) that can be converted.
_SUPPORTED_MODELS = {
    "bart",
    "multilingual_transformer",
    "transformer",
    "transformer_align",
    "transformer_lm",
}

# Maps Fairseq --activation-fn values to CTranslate2 activation types.
_SUPPORTED_ACTIVATIONS = {
    "gelu": common_spec.Activation.GELU,
    "gelu_accurate": common_spec.Activation.GELUTanh,
    "gelu_fast": common_spec.Activation.GELUTanh,
    "relu": common_spec.Activation.RELU,
    "swish": common_spec.Activation.SWISH,
}


def _get_model_spec(args):
    """Validates the Fairseq model options and builds the matching model spec.

    Arguments:
      args: Fairseq model options read from the checkpoint.

    Returns:
      A ``TransformerDecoderModelSpec`` for ``transformer_lm`` models, otherwise
      an encoder-decoder ``TransformerSpec``.

    Raises:
      Whatever ``utils.ConfigurationChecker.validate`` raises when an option
      is not supported.
    """
    import fairseq

    activation_fn = getattr(args, "activation_fn", "relu")
    model_name = fairseq.models.ARCH_MODEL_NAME_REGISTRY[args.arch]

    check = utils.ConfigurationChecker()
    check(
        model_name in _SUPPORTED_MODELS,
        "Model '%s' used by architecture '%s' is not supported (supported models are: %s)"
        # Sorted so the message does not depend on set iteration order.
        % (model_name, args.arch, ", ".join(sorted(_SUPPORTED_MODELS))),
    )
    # Validate the model first so the remaining checks only run on a supported model.
    check.validate()

    check(
        activation_fn in _SUPPORTED_ACTIVATIONS,
        "Option --activation-fn %s is not supported (supported activations are: %s)"
        % (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
    )
    check(
        not getattr(args, "no_token_positional_embeddings", False),
        "Option --no-token-positional-embeddings is not supported",
    )
    check(
        not getattr(args, "lang_tok_replacing_bos_eos", False),
        "Option --lang-tok-replacing-bos-eos is not supported",
    )

    if model_name == "transformer_lm":
        check(
            not args.character_embeddings,
            "Option --character-embeddings is not supported",
        )
        check(
            not args.adaptive_input,
            "Option --adaptive-input is not supported",
        )
        check.validate()

        return transformer_spec.TransformerDecoderModelSpec.from_config(
            args.decoder_layers,
            args.decoder_attention_heads,
            pre_norm=args.decoder_normalize_before,
            activation=_SUPPORTED_ACTIVATIONS[activation_fn],
            layernorm_embedding=getattr(args, "layernorm_embedding", False),
            no_final_norm=args.no_decoder_final_norm,
            project_in_out=args.decoder_input_dim != args.decoder_embed_dim,
        )

    else:
        # The encoder-decoder spec takes a single pre_norm flag and head count,
        # so encoder and decoder must agree on them.
        check(
            args.encoder_normalize_before == args.decoder_normalize_before,
            "Options --encoder-normalize-before and --decoder-normalize-before "
            "must have the same value",
        )
        check(
            args.encoder_attention_heads == args.decoder_attention_heads,
            "Options --encoder-attention-heads and --decoder-attention-heads "
            "must have the same value",
        )
        check.validate()

        return transformer_spec.TransformerSpec.from_config(
            (args.encoder_layers, args.decoder_layers),
            args.encoder_attention_heads,
            pre_norm=args.encoder_normalize_before,
            activation=_SUPPORTED_ACTIVATIONS[activation_fn],
            alignment_layer=getattr(args, "alignment_layer", -1),
            alignment_heads=getattr(args, "alignment_heads", 0),
            layernorm_embedding=getattr(args, "layernorm_embedding", False),
        )


def _get_vocab(dictionary):
    """Returns the symbols of a Fairseq dictionary as a list of tokens.

    The padding symbol "<pad>" is renamed to "<blank>". All other symbols are
    kept unchanged and in the same order, so token ids are preserved.
    """
    # NOTE(review): assumes CTranslate2 expects "<blank>" as its padding token
    # (as in the upstream Fairseq converter) — confirm against the target runtime.
    return ["<blank>" if token == "<pad>" else token for token in dictionary.symbols]


class FairseqConverter(Converter):
    """Converts models trained with Fairseq."""

    def __init__(
        self,
        model_path: str,
        data_dir: str,
        source_lang: Optional[str] = None,
        target_lang: Optional[str] = None,
        fixed_dictionary: Optional[str] = None,
        no_default_special_tokens: bool = False,
        user_dir: Optional[str] = None,
    ):
        """Initializes the Fairseq converter.

        Arguments:
          model_path: Path to the Fairseq PyTorch model (.pt file).
          data_dir: Path to the Fairseq data directory containing vocabulary files.
          source_lang: Source language (may be required if not declared in the model).
          target_lang: Target language (may be required if not declared in the model).
          fixed_dictionary: Path to the fixed dictionary for multilingual models.
          no_default_special_tokens: Require all special tokens to be provided by the user
            (e.g. encoder end token, decoder start token).
          user_dir: Path to the user directory containing custom extensions.
        """
        self._model_path = model_path
        self._data_dir = data_dir
        self._fixed_dictionary = fixed_dictionary
        self._source_lang = source_lang
        self._target_lang = target_lang
        self._no_default_special_tokens = no_default_special_tokens
        self._user_dir = user_dir

    def _load(self):
        """Loads the Fairseq checkpoint and returns the populated model spec."""
        import fairseq
        import torch

        from fairseq import checkpoint_utils

        if self._user_dir:
            from fairseq.utils import import_user_module

            # Register custom architectures/tasks before the model is built.
            import_user_module(argparse.Namespace(user_dir=self._user_dir))

        with torch.no_grad():
            checkpoint = checkpoint_utils.load_checkpoint_to_cpu(self._model_path)
            # Use the legacy "args" namespace when present and non-empty,
            # otherwise fall back to the model section of "cfg".
            args = checkpoint.get("args") or checkpoint["cfg"]["model"]

            # Point the task at the user-provided data directory and overrides.
            args.data = self._data_dir
            if self._fixed_dictionary is not None:
                args.fixed_dictionary = self._fixed_dictionary
            if getattr(args, "lang_dict", None):
                # Re-locate the language dictionary file inside data_dir.
                args.lang_dict = os.path.join(
                    self._data_dir, os.path.basename(args.lang_dict)
                )
            if self._source_lang is not None:
                args.source_lang = self._source_lang
            if self._target_lang is not None:
                args.target_lang = self._target_lang

            spec = _get_model_spec(args)

            task = fairseq.tasks.setup_task(args)
            model = fairseq.models.build_model(args, task)
            model.eval()
            model.load_state_dict(checkpoint["model"])

            if isinstance(spec, transformer_spec.TransformerDecoderModelSpec):
                set_transformer_decoder(
                    spec.decoder,
                    model.decoder,
                    with_encoder_attention=False,
                )

                spec.register_vocabulary(_get_vocab(task.dictionary))
                if not args.add_bos_token:
                    spec.config.bos_token = spec.config.eos_token

            else:
                set_transformer_encoder(spec.encoder, model.encoder)
                set_transformer_decoder(spec.decoder, model.decoder)

                spec.register_source_vocabulary(_get_vocab(task.source_dictionary))
                spec.register_target_vocabulary(_get_vocab(task.target_dictionary))
                if self._no_default_special_tokens:
                    spec.config.decoder_start_token = None
                else:
                    spec.config.decoder_start_token = spec.config.eos_token
                    spec.config.add_source_eos = True

            return spec


def set_transformer_encoder(spec, module):
    """Copies the weights of a Fairseq encoder into an encoder spec."""
    set_input_layers(spec, module)
    for layer_spec, layer in zip(spec.layer, module.layers):
        set_transformer_encoder_layer(layer_spec, layer)
    if module.layer_norm is not None:
        set_layer_norm(spec.layer_norm, module.layer_norm)
    if module.layernorm_embedding is not None:
        set_layer_norm(spec.layernorm_embedding, module.layernorm_embedding)


def set_transformer_decoder(spec, module, with_encoder_attention=True):
    """Copies the weights of a Fairseq decoder into a decoder spec.

    Arguments:
      spec: Target decoder spec.
      module: Fairseq decoder module.
      with_encoder_attention: Whether to copy the encoder-attention sublayers
        (False when converting a decoder-only language model).
    """
    set_input_layers(spec, module)
    set_linear(spec.projection, module.output_projection)
    for layer_spec, layer in zip(spec.layer, module.layers):
        set_transformer_decoder_layer(
            layer_spec,
            layer,
            with_encoder_attention=with_encoder_attention,
        )
    if module.layer_norm is not None:
        set_layer_norm(spec.layer_norm, module.layer_norm)
    if module.layernorm_embedding is not None:
        set_layer_norm(spec.layernorm_embedding, module.layernorm_embedding)
    # Optional projections, copied only when the Fairseq module defines them.
    if module.project_in_dim is not None:
        set_linear(spec.project_in, module.project_in_dim)
    if module.project_out_dim is not None:
        set_linear(spec.project_out, module.project_out_dim)


def set_input_layers(spec, module):
    """Copies position encodings, token embeddings and the embedding scale."""
    set_position_encodings(spec.position_encodings, module.embed_positions)
    # The spec may hold a list of embeddings; only the first one is filled here.
    set_embeddings(
        spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings,
        module.embed_tokens,
    )
    spec.scale_embeddings = module.embed_scale


def set_transformer_encoder_layer(spec, module):
    """Copies one encoder layer (feed-forward and self-attention)."""
    set_ffn(spec.ffn, module)
    set_multi_head_attention(spec.self_attention, module.self_attn, self_attention=True)
    set_layer_norm(spec.self_attention.layer_norm, module.self_attn_layer_norm)


def set_transformer_decoder_layer(spec, module, with_encoder_attention=True):
    """Copies one decoder layer, optionally including its encoder attention."""
    set_ffn(spec.ffn, module)
    set_multi_head_attention(spec.self_attention, module.self_attn, self_attention=True)
    set_layer_norm(spec.self_attention.layer_norm, module.self_attn_layer_norm)
    if with_encoder_attention:
        set_multi_head_attention(spec.attention, module.encoder_attn)
        set_layer_norm(spec.attention.layer_norm, module.encoder_attn_layer_norm)


def set_ffn(spec, module):
    """Copies the feed-forward sublayer (fc1, fc2 and its final layer norm)."""
    set_layer_norm(spec.layer_norm, module.final_layer_norm)
    set_linear(spec.linear_0, module.fc1)
    set_linear(spec.linear_1, module.fc2)


def set_multi_head_attention(spec, module, self_attention=False):
    """Copies an attention module, fusing its separate projections.

    Self-attention: q/k/v projections are fused into ``spec.linear[0]``.
    Encoder attention: q goes to ``spec.linear[0]`` and k/v are fused into
    ``spec.linear[1]``. The output projection always goes to ``spec.linear[-1]``.
    """
    if self_attention:
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], module.q_proj)
        set_linear(split_layers[1], module.k_proj)
        set_linear(split_layers[2], module.v_proj)
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        set_linear(spec.linear[0], module.q_proj)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], module.k_proj)
        set_linear(split_layers[1], module.v_proj)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], module.out_proj)


def set_layer_norm(spec, module):
    """Copies layer norm scale (gamma) and shift (beta)."""
    spec.gamma = module.weight.numpy()
    spec.beta = module.bias.numpy()


def set_linear(spec, module):
    """Copies a linear layer's weight and, when present, its bias."""
    spec.weight = module.weight.numpy()
    if module.bias is not None:
        spec.bias = module.bias.numpy()


def set_embeddings(spec, module):
    """Copies an embedding table."""
    spec.weight = module.weight.numpy()


def set_position_encodings(spec, module):
    """Copies position encodings from a learned or precomputed position module.

    ``torch.nn.Embedding`` modules expose ``.weight``; other position modules are
    read through ``.weights`` (presumably Fairseq's sinusoidal embedding — confirm).
    """
    import torch

    weight = module.weight if isinstance(module, torch.nn.Embedding) else module.weights
    # Drop rows up to and including padding_idx; presumably Fairseq numbers
    # positions starting at padding_idx + 1 — confirm.
    spec.encodings = weight.numpy()[module.padding_idx + 1 :]


def main():
    """Command-line entry point: parses options and runs the conversion."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model_path", required=True, help="Model path.")
    parser.add_argument(
        "--data_dir",
        required=True,
        help="Data directory containing the source and target vocabularies.",
    )
    parser.add_argument(
        "--user_dir",
        help="Directory containing custom extensions.",
    )
    parser.add_argument(
        "--fixed_dictionary",
        help="Fixed dictionary for multilingual models.",
    )
    parser.add_argument(
        "--source_lang",
        help="Source language. This argument is used to find dictionary file from `data_dir`.",
    )
    parser.add_argument(
        "--target_lang",
        help="Target language. This argument is used to find dictionary file from `data_dir`.",
    )
    parser.add_argument(
        "--no_default_special_tokens",
        action="store_true",
        help=(
            "Require all special tokens to be provided by the user during inference, "
            "including the decoder start token."
        ),
    )
    Converter.declare_arguments(parser)
    args = parser.parse_args()
    converter = FairseqConverter(
        args.model_path,
        args.data_dir,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        fixed_dictionary=args.fixed_dictionary,
        no_default_special_tokens=args.no_default_special_tokens,
        user_dir=args.user_dir,
    )
    converter.convert_from_args(args)


if __name__ == "__main__":
    main()