动漫人物头像识别

1. 概述

  • 当前目标
    输入一张动漫图片,可以自动检测图中的人物头像位置,然后进行人物识别
  • 终极目标
    当有新的人物加入时,可以用最小的代价将其加入可识别的人物列表中(动态增加新人物)

可惜暂时还没想好怎么做到终极目标

  • 总体框架
    首先使用 Yolo 网络检测头像
    再使用轻量级的图像识别网络对人物进行识别

其中 Yolo 网络使用 Yolov5,只负责检测头像,也就是在训练的时候直接将所有人物的头像当作一个类别来训练
轻量级图像识别网络使用 Resnet18 ,其以 Yolo 网络检测到的头像图片作为输入,输出识别到的人物编号

在线地址:动漫人物头像识别

2. 训练过程

2.1 数据获取

用爬虫从某个网站上借一些动漫图片(借.jpg)
分为爬取图片 URL 和爬取图片本身两部分

爬取图片 URL

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import time
from urllib import request
from urllib import error
from bs4 import BeautifulSoup


def download_raw_page(url, max_retries=2):
"""
下载整个页面的html文件
限速 0.5次爬取/秒
"""
print("Download from ", url)
html = None

try:
# 防止给服务器造成过大压力
# 设置限速,最多 2 秒发送一次爬取请求
time.sleep(2)
html = request.urlopen(url, timeout=1.0).read()
except error.HTTPError as e:
print("HTTPError: ", e.reason)
html = None
if max_retries > 0:
if hasattr(e, 'code') and (500 <= e.code < 600):
# 服务器内部错误,重试
return download_raw_page(url, max_retries - 1)
except error.URLError as e:
print("URLError: ", e.reason)

return html


def download_related_page(main_url, max_index=-1):
"""
根据主页面的URL
主页面的URL必须包含'?'
下载所有page的html文件
返回一个迭代器,每次迭代得到一个页面的html
"""

assert ("?" in main_url)

class download_related_page_iterator:
def __init__(self, main_url, max_index=10):
self.main_url = main_url
self.index = 1
self.max_index = max_index
assert (max_index > 1)

def __iter__(self):
return self

def __next__(self):
if self.index <= self.max_index:
html = download_raw_page(
self.main_url + "&page=" + str(self.index))
self.index += 1
return html
else:
raise StopIteration

# 获取当前页面,计算与这个页面相关的页面数量
html = download_raw_page(main_url)

# 检查结果是否正确
assert (html != None)

soup = BeautifulSoup(html, "html.parser")
max_index_temp = int(
soup.find(class_="pagination").find_all("a")[-2].string)

# 支持自定义最大爬取数量
if max_index == -1:
max_index = max_index_temp
else:
max_index = min(max_index_temp, max_index)

return download_related_page_iterator(main_url, max_index)


def get_img_url_from_page(html):
"""
输入html文本
返回一个列表,其中有所有所有要爬的图片的绝对地址
"""
img_links = []

soup = BeautifulSoup(html, "html.parser")
lists = soup.find(id="post-list-posts")
for img in lists.find_all(class_="directlink largeimg"):
img_links.append(img["href"])

return img_links


def get_img_url_from_related_page(main_url, max_index=-1):
"""
根据主页面的URL
主页面的URL必须包含'?'
获取所有所有要爬的图片的绝对地址
"""
img_links = []
for html in download_related_page(main_url, max_index):
img_links.extend(get_img_url_from_page(html))

return img_links


if __name__ == '__main__':


# 明日香
# character = "souryuu_asuka_langley"
# 友利奈绪
# character = "tomori_nao"
# 泽村·英梨梨
# character = "sawamura_spencer_eriri"

# 由比滨结衣
# character = "yuigahama_yui"
# 木之本樱
# character = "kinomoto_sakura"
# saber
# character = "saber"

# 艾米利亚
# character = "emilia_(re_zero)"
# 高坂桐乃
# character = "kousaka_kirino"
# 樱岛麻衣
# character = "sakurajima_mai"
# 小林托尔
character = "tooru_(kobayashi-san_chi_no_maid_dragon)"


main_url = "?????????(网站URL.jpg)" + character
max_index = 5

base_dir = "image_url/"
character_url_file = base_dir + character + ".txt"

img_links = get_img_url_from_related_page(main_url, max_index=max_index)

with open(character_url_file, 'w') as file:
for url in img_links:
file.write(url + "\n")

爬取图片本体

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import time
from urllib import request
from urllib import error
import os


def download_image(url, max_retries=2):
"""
限速 0.5次爬取/秒
"""
print("Download from ", url)
image = None

try:
# 防止给服务器造成过大压力
# 设置限速,最多 2 秒发送一次爬取请求
time.sleep(2)
image = request.urlopen(url, timeout=1.0).read()
except error.HTTPError as e:
print("HTTPError: ", e.reason)
image = None
if max_retries > 0:
if hasattr(e, 'code') and (500 <= e.code < 600):
# 服务器内部错误,重试
return download_image(url, max_retries - 1)
except OSError as e:
print("OSError: ", e.strerror)
image = None

return image


def save_image(file, content):
print("save image to ", file)
with open(file, "wb") as file:
file.write(content)


if __name__ == '__main__':

# 明日香
character = "souryuu_asuka_langley"
# 友利奈绪
# character = "tomori_nao"
# 泽村·英梨梨
# character = "sawamura_spencer_eriri"

# character = "yuigahama_yui"
# character = "kinomoto_sakura"
# character = "saber"

character_dir = "image_file/" + character + "/"
character_url_file = "image_url/" + character + ".txt"
if not os.path.exists(character_dir):
os.mkdir(character_dir)

# 遍历所有url
with open(character_url_file, "r") as url_file:
index = 0
for url in url_file.readlines():
image = download_image(url)
if image is not None:
des_file = character_dir + str.format("image_{}.jpg", index)
index += 1
save_image(des_file, image)

2.2 数据标注

  1. 首先手动删除一下不太好的图片和重复的图片
  2. 然后使用 LabelMe 对其进行标注

  3. 使用LabelmeToYolo小工具把标签的格式从 json 转换出 Yolo 格式

  4. 在训练好 Yolo 模型之后,使用 Yolo 模型把原图片头像部分截取出来(这一部分代码丢失了…)
  5. 手动把截出来的头像分类

2.3 Yolo 网络训练

参考Yolov5官方教程

2.4 ResNet18 网络训练

参考 Pytorch 网络微调教程,部分细节如下

  1. 添加数据增强

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    # 加载数据
    # 数据转换部分
    data_transforms = {
    'train': transforms.Compose([
    # 对训练数据进行数据增强
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    }
  2. 训练过程

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history = {
    'train': {
    'acc': [],
    'loss': []
    },
    'val': {
    'acc': [],
    'loss': []
    }
    }

    for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 分为训练和验证两个阶段
    for phase in ['train', 'val']:
    if phase == 'train':
    model.train()
    else:
    model.eval()

    running_loss = 0.0
    running_corrects = 0

    # Iterate over data.
    for inputs, labels in dataloaders[phase]:
    inputs = inputs.to(device)
    labels = labels.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward
    # track history if only in train
    with torch.set_grad_enabled(phase == 'train'):
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    loss = criterion(outputs, labels)

    # backward + optimize only if in training phase
    if phase == 'train':
    loss.backward()
    optimizer.step()

    # statistics
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data)
    if phase == 'train':
    scheduler.step()

    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects.double() / dataset_sizes[phase]

    print('{} Loss: {:.4f} Acc: {:.4f}'.format(
    phase, epoch_loss, epoch_acc))
    history[phase]['acc'].append(epoch_acc)
    history[phase]['loss'].append(epoch_loss)

    # deep copy the model
    if phase == 'val' and epoch_acc > best_acc:
    best_acc = epoch_acc
    best_model_wts = copy.deepcopy(model.state_dict())

    print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 加载最佳参数
    model.load_state_dict(best_model_wts)
    return model,history
  3. 加载模型与训练

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 使用预训练了的resnet18作为识别网络
    model_ft = models.resnet18(pretrained=True)

    # 不冻结前面的参数
    # 之所以这么做,是因为也对冻结浅层参数的做法进行了测试,但是效果不如不冻结
    # 可能是因为ResNet18原本不是在动漫图上训练的原因吧

    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, len(class_names))

    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    # 使用随机梯度下降算法
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=CONFIG_MODEL['learning_rate'], momentum=0.9)

    # 学习率每7个epoch减少10%
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model_ft,history = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
    num_epochs=25)
  4. 训练过程可视化

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    train_history = {
    'train': {
    'acc': list(map(Tensor.cpu, history['train']['acc'])),
    'loss': history['train']['loss']
    },
    'val': {
    'acc': list(map(Tensor.cpu, history['val']['acc'])),
    'loss': history['val']['loss']
    }
    }

    # 画图
    l1 = plt.plot(train_history['train']['acc'], linestyle=':',marker='o', label='train acc')
    l2 = plt.plot(train_history['val']['acc'], linestyle='--',marker='s', label='val acc')
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.ylim(0,1)
    plt.title(u'准确率的变化')
    plt.legend(loc='lower right')
    # 保存图像与参数
    save_path = CONFIG_PATH['train_log_dir'] / CONFIG_MODEL['log_dir']
    if not save_path.exists():
    save_path.mkdir(parents=True)
    plt.savefig(save_path / "train_val_acc")
    np.save(save_path / "train_acc",train_history['train']['acc'])
    np.save(save_path / "val_acc",train_history['val']['acc'])

    plt.show()

    # 画图
    l1 = plt.plot(train_history['train']['loss'], linestyle=':',marker='o', label='train loss')
    l2 = plt.plot(train_history['val']['loss'], linestyle='--',marker='s', label='val loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title(u'Loss的变化')
    plt.legend(loc='upper right')
    # 保存图像与参数
    save_path = CONFIG_PATH['train_log_dir'] / CONFIG_MODEL['log_dir']
    if not save_path.exists():
    save_path.mkdir(parents=True)
    plt.savefig(save_path / "train_val_loss")
    np.save(save_path / "train_loss",train_history['train']['loss'])
    np.save(save_path / "val_loss",train_history['val']['loss'])

    plt.show()
  5. 保存模型

    1
    2
    3
    4
    5
    # 对比几个 batch size 大小时的训练情况
    # batch size 小的时候泛化准确率反而很高
    # 最终选用 4 作为 batch size 来进行训练并保存模型
    model_path = CONFIG_PATH['model_save_dir'] / "resnet18.pkl"
    torch.save(model_ft.to('cpu'),model_path)

    3.网站搭建

使用 flask 搭建后台,Vue 搭建前台(当时还不太会 Vue,写的是一塌糊涂)

项目地址:头像识别网站搭建

头像识别路由代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@app.route('/detect', methods=['POST'])
def detect():
"""识别动漫人物头像位置并识别人物

Returns:
json: {'result':[
[x,y,x,y,probability,name],
...
]}
"""
if request.method == 'POST':
# 读取传来的图片
file = request.files['image']
imagebytes = file.read()
image = Image.open(io.BytesIO(imagebytes))

boxes = __detect(image)
result = __recognize(image, boxes)

return jsonify(result)

return jsonify([
# [x,y,x,y,probability]
{
"box":[1,1,3,3,0.9],
"name": "name",
"trans": "trans"
}
])


def __detect(image: Image):
"""识别图中头像的位置

Args:
image (Image): PIL.Image对象,被识别的图片

Returns:
list: 识别到的头像框位置[(x,y,x,y,probability,class),...]
"""
boxes = yolov5(image).xyxyn[0]

# app.logger.debug(type(box))
# app.logger.debug(box)

return boxes.tolist()


def __recognize(image: Image, boxes: list):
"""识别头像人物

Args:
image (Image): 原始图片
boxes (list): 头像位置

Returns:
list: 识别到的头像框位置及人物名字[(x,y,x,y,probability,name),...]
"""
result = []
for box in boxes:
width, height = image.size
head_box = (
width * box[0], height * box[1],
width * box[2], height * box[3]
)
head_image = image.crop(head_box)

name,trans = resnet18(head_image)
result.append({
"box":box[:5],
"name":name,
"trans":trans
})

app.logger.debug(result)
return result

模型调用部分代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# 加载 yolov5 模型,选用最小的模型
yolov5 = torch.hub.load(YOLO_PATH, 'custom', path=YOLO_MODEL, source='local')
yolov5.eval()

# 加载 resnet18 模型
resnet18_model = torch.load(RESNET18_MODEL)
resnet18_model.to('cpu')
resnet18_model.eval()

class_names = []
translate = {}
# 加载标签
with open(RESNET18_LABELS,'r',encoding='utf-8') as file:
for line in file.readlines():
class_names.append(line.strip('\n'))
with open(RESNET18_TRANS, 'r', encoding='utf-8') as file:
translate = json.load(file)

transformer = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def resnet18(image):
"""给resnet18包装一个数据预处理器
"""
temp = transformer(image)
outputs = resnet18_model(
temp[None,...]
)
_, preds = torch.max(outputs, 1)
name = class_names[preds[0]]
return name,translate[name]


app.logger.debug('acgmodel loaded')

__all__ = ['yolov5', 'resnet18']

4. 结果展示

4.1 训练过程

设置学习率为 0.0001,batch size 为 4,8,16,32 进行测试

最终采用 batch size==8 情况下训练的模型

4.2 网站搭建结果