Featured image of post Chinesellamla 代码拆解

Chinesellamla 代码拆解

run_clm_sft_with_peft.py

先看引入的 SDK

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import logging # 加载日志系统用的
import numpy as np 
import math
import os  # 用于操作系统级别的功能,如文件路径操作
import sys  # 提供访问与Python解释器相关的变量和函数
from dataclasses import dataclass, field # 这个模块提供了一个装饰器和一些函数,用于自动为用户自定义的类添加生成的 特殊方法 例如 __init__() 和 __repr__()。
from itertools import chain # 链式迭代器 e,g: chain('ABC', 'DEF') → A B C D E F
from typing import Optional, List, Dict, Any, Mapping # 用来标识类型
from pathlib import Path # 高级的文件和路径操作工具
import datasets
import torch
from datasets import load_dataset, concatenate_datasets # 用来加载和串联数据集
from build_dataset import build_instruction_dataset, DataCollatorForSupervisedDataset  # 从自定义模块 build_dataset 中导入构建和处理训练数据的工具

# build_instruction_dataset:构建指令微调数据集的函数,通常会将原始样本转换为“输入 + 指令 + 输出”的格式,用于训练大模型进行指令跟随任务
# DataCollatorForSupervisedDataset:用于监督微调的批处理器(Data Collator),在模型训练时将样本整理成 batch,通常包括对输入进行 padding、mask 构建等处理

import transformers # 加载transformers
from transformers import (
    CONFIG_MAPPING, # 模块配置映射
    MODEL_FOR_CAUSAL_LM_MAPPING, # 处理因果语言建模(Causal LM)的模型映射 
    AutoConfig, # 自动加载配置文件
    AutoModelForCausalLM, # 自动加载因果语言建模的模型
    LlamaForCausalLM, # LLaMA 专用的因果语言建模
    LlamaTokenizer, # LLaMA 的分词器
    AutoTokenizer, # 自动加载合适的分词器
    HfArgumentParser, # 用于解析Hugging face的命令行参数
    Trainer, # 用于训练模型的工具
    TrainingArguments, # 存储参数的类
    is_torch_tpu_available, # 检查是否有TPU可用
    set_seed, # 设置随机种子
)
from transformers.testing_utils import CaptureLogger  # 捕捉并测试日志输出
from transformers.trainer_utils import get_last_checkpoint # 获取上次训练的检查点
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR # 定义检查点的目录前缀

from transformers.utils import send_example_telemetry # 发送事例信息,用于追踪
from transformers.utils.versions import require_version # 确保特定版本的库可用

from sklearn.metrics import accuracy_score # 计算模型准确率
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict # PEFT (Parameter Efficient Fine-Tuning)工具库,用于高效微调模型
# LoraConfig:LoRA(Low-Rank Adaptation)配置类,用于定义LoRA微调时的超参数(如r值、alpha、dropout等)
# TaskType:任务类型的枚举类,用于指定所进行的任务类型(如“CAUSAL_LM”,表示因果语言建模)
# get_peft_model:将原始模型包装成PEFT模型,注入LoRA等适配模块,返回一个可训练的PEFT模型
# PeftModel:PEFT模型的基类,所有PEFT模型的父类,包含保存、加载、推理等方法
# get_peft_model_state_dict:获取PEFT模型的状态字典(仅包含可训练的PEFT参数),用于保存和加载

再看第一部分

1
2
3
4
5
6
7
IGNORE_INDEX = -100  # 在训练时用于标记应被损失函数忽略的位置(比如被padding或未对齐的标签),如CrossEntropyLoss默认会忽略-100位置
DEFAULT_PAD_TOKEN = "[PAD]" # 用于对齐
DEFAULT_EOS_TOKEN = "</s>" # 用于回答的开始
DEFAULT_BOS_TOKEN = "<s>" # 结束
DEFAULT_UNK_TOKEN = "<unk>" # 词表未包含的词

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") # 确保版本设置正确,如果datasets版本不正确,终端返回后面的提示

SavePeftModelCallback

是在使用 PEFT(Parameter-Efficient Fine-Tuning)库 时,为了正确保存 LoRA 等微调模块专门写的一个 回调(Callback)类

回调函数是一种特殊的函数,它作为参数传递给另一个函数,并在被调用函数执行完毕后被调用。回调函数通常用于事件处理、异步编程和处理各种操作系统和框架的API。

为什么要有这个类?——答:因为 Trainer 默认会保存整个模型,但是LoRA只需要保存训练的部分

global_stepXXX 这些文件夹内的这些文件来自于 DeepSpeed 的训练过程,它们是 DeepSpeed 分布式训练中的 checkpoint 文件。每个文件的名字都体现了并行训练的维度,比如 模型并行(MP)零冗余优化器(ZeRO)并行

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class SavePeftModelCallback(transformers.TrainerCallback):
    def save_model(self, args, state, kwargs):
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "sft_lora_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") # 设置保存文件夹

        peft_model_path = os.path.join(checkpoint_folder, "sft_lora_model") # 设定保存模型路径
        kwargs["model"].save_pretrained(peft_model_path) # 保存peft模型
        kwargs["tokenizer"].save_pretrained(peft_model_path) # 保存分词器

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        peft_model_path = os.path.join(args.output_dir, "sft_lora_model")
        kwargs["model"].save_pretrained(peft_model_path)
        kwargs["tokenizer"].save_pretrained(peft_model_path)

总结代码的流程

时间点 保存内容 保存路径
🧩 每次保存(on_save) PEFT 模型 + tokenizer checkpoint-xxx/sft_lora_model/
✅ 训练结束(on_train_end) 最终模型 + tokenizer output_dir/sft_lora_model/
🔄 自动执行 通过 Trainer 自动触发 无需你手动调用

ModelAugenments

设置模型的各种参数,例如: model , config, tokenizer .

现在的huggingface库里面 Tokenizer 有两种,一种就是普通的,另一种是fast的。fast和普通的区别就是fast使用rust语言编写,在处理大量文本的时候会更快。

Python 的 dataclasses 模块中,field() 是一个用于自定义数据类(dataclass)字段属性的函数。它可以设置字段的默认值、默认工厂函数(default_factory)、是否包含在比较中(compare)、是否包含在 __repr__ 中(repr),是否参与初始化(init)等等。

这个 __post_init__ 方法是 Python dataclasses 提供的一个“后初始化钩子函数”。它在 __init__ 方法执行完毕后自动调用,用于执行一些额外的初始化逻辑或参数检查。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )

    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

    def __post_init__(self): # 这里的作用是在初始化后校验几个参数的合法性,这段表示如果输入了config_overrides,就不能输入config_name或者model_name_or_path 
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

DataTrainingArguments

preprocessing_num_workers:它是用来加速数据预处理阶段的 —— 特别是在数据量大的时候。

什么是预处理?

训练模型之前,你需要把原始文本数据转换成 模型可读的格式(token ids),这个过程叫做 预处理,通常包括:

  • 文本清洗(去空格、符号)
  • 分词(tokenization)
  • padding/truncation,也是在指令学习padding对齐的部分应该就是在预处理做的。
  • 数据格式转换(如 dict → features)

这一阶段通常是通过 datasets.Dataset.map() 方法完成的:

1
dataset.map(preprocess_function, num_proc=8) # 相比于单进程并行处理要更快,当然需要考虑硬件(CPU)的支持

这里的 num_proc=8 就是咱们说的 preprocessing_num_workers

sequence 通常指的是:一段文本在经过 tokenizer 分词之后的 token 序列

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
		# 数据集文件夹
    dataset_dir: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
		# 训练数据文件
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    # 验证集文件,用于模型定期性能评估
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
		# 是否覆盖缓存的数据和评估集
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    # 在没有验证拆分的情况下,用作验证集的训练数据的百分比
    validation_split_percentage: Optional[float] = field(
        default=0.05,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    # 用于预处理的进程数量
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # 在读取txt文件时是否保留换行符
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )
    # 存储已处理的数据集地址
    data_cache_dir: Optional[str] = field(default=None, metadata={"help": "The datasets processed stored"})
		# 最大序列(sequence)长度
    max_seq_length: Optional[int] = field(default=512)

MyTrainAuguments

定义一些自己设计的参数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
@dataclass
class MyTrainingArguments(TrainingArguments):
    trainable : Optional[str] = field(default="q_proj,v_proj") # 定义测参数的层,默认是Q和V两个输入,也就是说,只对q_proj,v_proj两个层训练,也可选k_proj和o_proj
    lora_rank : Optional[int] = field(default=8) # low-rank
    lora_dropout : Optional[float] = field(default=0.1) 
    lora_alpha : Optional[float] = field(default=32.) # 学习率
    modules_to_save : Optional[str] = field(default=None) # 额外的,需要保存的模块,此时这些模块也需要训练常用的包括 "embed_tokens", "lm_head".
    peft_path : Optional[str] = field(default=None) # 已经训练过的PEFT模型,可以再次训练或者
    force_resize_embeddings: bool = field(default=False) # 用于扩展词表,增大预训练模型的embedding size


logger = logging.getLogger(__name__) #  Python 日志系统(logging 模块)的标准写法,用来创建一个 logger(日志记录器)对象,方便你在程序中打印日志信息。 

main 函数

with training_args.main_process_first(...)

这是 Hugging Face 提供的一个 用于分布式训练环境中的上下文管理器。其作用是:

  • 在多进程环境(比如多GPU训练)中,确保只有主进程会先执行下面的代码块
  • 其余进程会等待主进程完成,比如加载和缓存数据,避免重复下载或冲突。

torch_dtype

数据类型 描述 用途说明
torch.float32 单精度浮点数(默认) 默认精度,适合大多数任务
torch.float16 半精度浮点数(FP16) 适合加速推理和节省显存(需支持)
torch.bfloat16 Brain Floating Point(BF16) 比 FP16 更稳定,用于 TPU / A100 等
"auto" / None 自动推断或使用默认 Hugging Face 会自动决定使用何种精度

getattr(torch, model_args.torch_dtype) 是干嘛的?

这是 Python 的内置函数,用来动态获取对象的属性

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def main():
		
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments)) # 加载huggingface参数解析器,加载上述定义的一些参数
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) # 如果参数是json文件
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses() # 参数是文本

    send_example_telemetry("run_clm", model_args, data_args) # 给Hugging face官方发送反馈信息,包含model_args和data_args,可关闭,对模型无影响,不重要

    # Setup logging,设置log的基本形式?
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,  # if training_args.local_rank in [-1, 0] else logging.WARN,
        handlers=[logging.StreamHandler(sys.stdout)],)


    if training_args.should_log: # 没有自定义但是属于 TrainingAuguments 里面的参数
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info() 

    log_level = training_args.get_process_log_level() 
    logger.setLevel(log_level) # 设置日志等级,应该是INFO
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler() # 为 Transformers 的日志系统启用一个默认的输出处理器(handler),一般就是标准输出(console)。
    transformers.utils.logging.enable_explicit_format() # 设置日志的显示格式为 Transformers 默认格式(带时间戳、日志级别、模块名等)。
    # transformers.tokenization_utils.logging.set_verbosity_warning()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
      # 当output_dir是一个文件夹,do_train 为 True,并且overwrite_output_dir为False,这是加载之前的训练点的前提
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
          # 无法加载检查点,并且已经存在文件了,报错
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
           # 加载了检查点,并且传入参数允许从检查点恢复
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed) 

    
    config_kwargs = {
        "cache_dir": model_args.cache_dir, # 加载模型的缓存地址,好像是直接从网上下
        "revision": model_args.model_revision, # 指定要下载的模型的版本或修订版
        "use_auth_token": True if model_args.use_auth_token else None, # 控制是否使用身份验证令牌(auth token)来访问模型。
    }
    if model_args.config_name: # 指定了相应的模型名称
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) # 直接从 Hugging face上面下载
    elif model_args.model_name_or_path: # model_name_or_path 可以是 Hugging Face 上的模型名称,也可以是本地路径。
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
      '''
      在这种情况下,config 将通过 CONFIG_MAPPING[model_args.model_type]() 创建一个新的配置实例。CONFIG_MAPPING 是一个字典,通常映射了模型类型(如 bert, gpt2 等)到相应的配置类。然后,logger.warning() 会输出一条警告,表示你正在从头开始创建一个新的配置实例。
      '''
        config = CONFIG_MAPPING[model_args.model_type]() 
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir, # 在下载缓存的地址下面找tokenizer
        "use_fast": model_args.use_fast_tokenizer, 
        "revision": model_args.model_revision, 
        "use_auth_token": True if model_args.use_auth_token else None, 
    }
    if model_args.tokenizer_name: # 下载指定模型的tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.tokenizer_name_or_path: # 路径可以是本地或者网络路径
        tokenizer = LlamaTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if (len(tokenizer))!=49954:  # tokenizer 大小不对
        raise ValueError(f"The vocab size of the tokenizer must be 49954, but found {len(tokenizer)}.\n"
                         "Please use Chinese Alpaca tokenizer!")
    if tokenizer.pad_token is None: # 需要加入特殊字符pad_token来做指令微调
        print(f"Adding pad token {DEFAULT_PAD_TOKEN}")
        tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
		
    # 定义一个数据收集器
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    eval_dataset=None
    train_dataset = None
		
    # 训练数据加载和预先处理
    if training_args.do_train: # 当是训练状态
        with training_args.main_process_first(desc="loading and tokenization"): # 只有在主进程才加载数据 
            path = Path(data_args.dataset_dir)
            files = [os.path.join(path,file.name) for file in path.glob("*.json")]
            logger.info(f"Training files: {' '.join(files)}")
            train_dataset = build_instruction_dataset(
                data_path=files,
                tokenizer=tokenizer,
                max_seq_length=data_args.max_seq_length,
                data_cache_dir = None,
                preprocessing_num_workers = data_args.preprocessing_num_workers)
        logger.info(f"Num train_samples  {len(train_dataset)}")
        logger.info("training example:")
        logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
    if training_args.do_eval: # 验证(Evaluation)数据加载与预处理部分
        with training_args.main_process_first(desc="loading and tokenization"): # 主进程加载验证集
            files = [data_args.validation_file]
            logger.info(f"Evaluation files: {' '.join(files)}")
            eval_dataset = build_instruction_dataset(
                data_path=files,
                tokenizer=tokenizer,
                max_seq_length=data_args.max_seq_length,
                data_cache_dir = None,
                preprocessing_num_workers = data_args.preprocessing_num_workers)
        logger.info(f"Num eval_samples  {len(eval_dataset)}")
        logger.info("eval example:")
        logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
		
    # 如果是模型名称或者路径
    if model_args.model_name_or_path:
      # 设置 PyTorch 模型的权重数据类型(torch_dtype),torch_dtype 表示 PyTorch 模型中张量的数据类型(tensor data type),即模型权重加载时的精度。
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None] # 未知或者自动推断
            else getattr(torch, model_args.torch_dtype) # torch_dtype = torch.float16,假设 model_args.torch_dtype 是 "float16"
        )
        # 加载一个 预训练的 LLaMA(或兼容)模型,用于 自回归语言建模(Causal Language Modeling)
        model = LlamaForCausalLM.from_pretrained( # 用于从本地或远程模型目录加载模型权重和配置。
            model_args.model_name_or_path, 
            from_tf=bool(".ckpt" in model_args.model_name_or_path), # 判断模型是否是 TensorFlow 格式(.ckpt 文件),如果是,就走 TensorFlow 到 PyTorch 的转换流程。
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True # 节省内存
        )
    else:
        model = AutoModelForCausalLM.from_config(config) # 若果不是,那么可能直接下载
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

    logger.info(f"len(tokenizer):{len(tokenizer)}")
    embedding_size = model.get_input_embeddings().weight.shape[0] # embedding大小
    if len(tokenizer) != embedding_size: 
        logger.info("resize the embedding size by the size of the tokenizer")
        model.resize_token_embeddings(len(tokenizer))

    if training_args.peft_path is not None:
        logger.info("Peft from pre-trained model")
        model = PeftModel.from_pretrained(model, training_args.peft_path)
    else: # 没有PEFT模型直接新建
        logger.info("Init new peft model")
        target_modules = training_args.trainable.split(',')
        modules_to_save = training_args.modules_to_save
        if modules_to_save is not None:
            modules_to_save = modules_to_save.split(',')
        lora_rank = training_args.lora_rank
        lora_dropout = training_args.lora_dropout
        lora_alpha = training_args.lora_alpha
        logger.info(f"target_modules: {target_modules}")
        logger.info(f"lora_rank: {lora_rank}")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=target_modules,
            inference_mode=False,
            r=lora_rank, lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            modules_to_save=modules_to_save)
        model = get_peft_model(model, peft_config)

    #model.base_model.tie_weights()
    model.print_trainable_parameters()
    logger.info(f"model.modules_to_save: {model.modules_to_save}")
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.add_callback(SavePeftModelCallback)

    # Training,寻来呢
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
				# 记录和保存训练过程中的评估指标和状态
        metrics = train_result.metrics # 训练结果矩阵?loss, epoch, runtime, samples_per_second

        metrics["train_samples"] = len(train_dataset) # 手动添加一个自定义指标:训练样本数(train_samples),方便日后分析。

        trainer.log_metrics("train", metrics) # 训练指标输出到日志 
        trainer.save_metrics("train", metrics)
        trainer.save_state() # 保存状态

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()
        metrics["eval_samples"] =len(eval_dataset)
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()

未完待续…

参考资料:

CHinese-LLaMA-Alpace

自定义文本
使用 Hugo 构建
主题 StackJimmy 设计