作者|李秋琪
书| Carol
制作| AI技术大本营(ID:rgznai100)
最近几天,一个GitHub项目放火烧了整个朋友圈。就是漫画头像AI生成小程序。
如下图所见:而这个项目的基本原理是用Python搭建的GAN算法模型,进行训练得出。
而所谓的GAN就是指生成对抗网络深度学习模型。网络中有生成器G(generator)和鉴别器(Discriminator)。有两个数据域分别为X,Y。G 负责把X域中的数据拿过来拼命地模仿成真实数据并把它们藏在真实数据中,而 D 就拼命地要把伪造数据和真实数据分开。经过二者的博弈以后,G 的伪造技术越来越厉害,D 的鉴别技术也越来越厉害。直到 D 再也分不出数据是真实的还是 G 生成的数据的时候,这个对抗的过程达到一个动态的平衡。
而CycleGAN本质上是两个镜像对称的GAN,构成了一个环形网络。
两个GAN共享两个生成器,并各自带一个判别器,即共有两个判别器和两个生成器。一个单向GAN两个loss,两个即共四个loss。
可以实现无配对的两个图片集的训练是CycleGAN与Pixel2Pixel相比的一个典型优点。但是我们仍然需要通过训练创建这个映射来确保输入图像和生成图像间存在有意义的关联,即输入输出共享一些特征。
简而言之,该模型通过从域DA获取输入图像,该输入图像被传递到第一个生成器GeneratorA→B,其任务是将来自域DA的给定图像转换到目标域DB中的图像。然后这个新生成的图像被传递到另一个生成器GeneratorB→A,其任务是在原始域DA转换回图像,这里可与自动编码器作对比。这个输出图像必须与原始输入图像相似,用来定义非配对数据集中原来不存在的有意义映射。
在本次的项目中就是利用了CycleGAN进行搭建模型。模型训练数据集如下:
实验前的准备
首先我们使用的python版本是3.6.5所用到的库有pytorch和TensorFlow,用来训练和加载神经网络常见的框架;face-alignment用来是用来提取人脸特征的常用库;
dlib是一个机器学习的开源库,包含了机器学习的很多算法,使用起来很方便,直接包含头文件即可,并且不依赖于其他库(自带图像编解码库源码)。Dlib可以帮助您创建很多复杂的机器学习方面的软件来帮助解决实际问题。目前Dlib已经被广泛的用在行业和学术领域,包括机器人,嵌入式设备,移动电话和大型高性能计算环境。
模型的训练
1、数据集处理和准备:
训练数据包括真实照片和卡通画像,为降低训练复杂度,我们对两类数据进行了如下预处理:
· 检测人脸及关键点。
· 根据关键点旋转校正人脸。
· 将关键点边界框按固定的比例扩张并裁剪出人脸区域。
· 使用人像分割模型将背景置白。
为了形成匹配效果,需要准备一些卡通人物图片和真实的人脸图片进行训练
2、模型的训练:
模型的训练使用python --dataset photo2cartoon进行训练即可。
3、神经网络结构搭建:
整个算法的搭建正如上面可见,需要有生成器和判别器。使用论文提出的一种Soft-AdaLIN(Soft Adaptive Layer-Instance Normalization)归一化方法,在反规范化时将编码器的均值方差(照片特征)与解码器的均值方差(卡通特征)相融合。
模型结构方面,在U-GAT-IT的基础上,在编码器之前和解码器之后各增加了2个hourglass模块,渐进地提升模型特征抽象和重建能力。
部分代码如下:
class ResnetGenerator):def __init__(self, ngf=64, img_size=256, light=False):
super(ResnetGenerator, self).__init__
= light
= nn.Sequential(3),
nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
nn.InstanceNorm2d(ngf),
nn.ReLU(True))
= HourGlass(ngf, ngf)
= HourGlass(ngf, ngf)
# Down-Sampling
= nn.Sequential(1),
nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
nn.InstanceNorm2d(ngf * 2),
nn.ReLU(True))
= nn.Sequential(1),
nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
nn.InstanceNorm2d(ngf*4),
nn.ReLU(True))
# Encoder Bottleneck
= ResnetBlock(ngf*4)
= ResnetBlock(ngf*4)
= ResnetBlock(ngf*4)
= ResnetBlock(ngf*4)
# Class Activation Map
= nn.Linear(ngf*4, 1)
= nn.Linear(ngf*4, 1)
= nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
= nn.ReLU(True)
# Gamma, Beta block
if :
= nn.Sequential(ngf*4, ngf*4),
nn.ReLU(True),
nn.Linear(ngf*4, ngf*4),
nn.ReLU(True))
else:
= nn.Sequential(img_size//4*img_size//4*ngf*4, ngf*4),
nn.ReLU(True),
nn.Linear(ngf*4, ngf*4),
nn.ReLU(True))
# Decoder Bottleneck
= ResnetSoftAdaLINBlock(ngf*4)
= ResnetSoftAdaLINBlock(ngf*4)
= ResnetSoftAdaLINBlock(ngf*4)
= ResnetSoftAdaLINBlock(ngf*4)
# Up-Sampling
= nn.Sequential(scale_factor=2),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
LIN(ngf*2),
nn.ReLU(True))
= nn.Sequential(scale_factor=2),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
LIN(ngf),
nn.ReLU(True))
= HourGlass(ngf, ngf)
= HourGlass(ngf, ngf, False)
= nn.Sequential(3),
nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
nn.Tanh)
def forward(self, x):
x = (x)
x = (x)
x = (x)
x = (x)
x = (x)
x = (x)
content_features1 = F.adaptive_avg_pool2d(x, 1).view[0], -1)
x = (x)
content_features2 = F.adaptive_avg_pool2d(x, 1).view[0], -1)
x = (x)
content_features3 = F.adaptive_avg_pool2d(x, 1).view[0], -1)
x = (x)
content_features4 = F.adaptive_avg_pool2d(x, 1).view[0], -1)
gap = F.adaptive_avg_pool2d(x, 1)
gap_logit = [0], -1))
gap_weight = list(.parameters)[0]
gap = x * ga(2).unsqueeze(3)
gmp = F.adaptive_max_pool2d(x, 1)
gmp_logit = [0], -1))
gmp_weight = list(.parameters)[0]
gmp = x * gm(2).unsqueeze(3)
cam_logit = ([gap_logit, gmp_logit], 1)
x = ([gap, gmp], 1)
x = ((x))
heatmap = (x, dim=1, keepdim=True)
if :
x_ = F.adaptive_avg_pool2d(x, 1)
style_features = [0], -1))
else:
style_features = [0], -1))
x = (x, content_features4, style_features)
x = (x, content_features3, style_features)
x = (x, content_features2, style_features)
x = (x, content_features1, style_features)
x = (x)
x = (x)
x = (x)
x = (x)
out = (x)
return out, cam_logit, heatmap
4、提取人脸特征:
为了提取人脸特征以达到加载到网络中的目的,我们需要正确框出人脸同时计算特征距离,以方便后面训练模型师损失函数的调用。
代码如下:
class FaceFeatures(object):def __init__(self, weights_path, device):
= device
= MobileFaceNet(512).to(device)
.load_state_dict(weights_path))
.eval
def infer(self, batch_tensor):
# crop face
h, w = ba[2:]
top = int(h / 2.1 * - 0.33))
bottom = int(h - (h / 2.1 * 0.3))
size = bottom - top
left = int(w / 2 - size / 2)
right = left + size
batch_tensor = batch_tensor[:, :, top: bottom, left: right]
batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)
features = (batch_tensor)
return features
def cosine_distance(self, batch_tensor1, batch_tensor2):
feature1 = (batch_tensor1)
feature2 = (batch_tensor2)
return 1 - (feature1, feature2)
模型测试
在训练好模型后,我们使用python --photo_path ./image --save_path ./image测试生成图片。其中1.jpg是原始图片,最终会生成2.jpg图片。
使用python da --data_path YourPhotoFolderPath --save_path YourSaveFolderPath批量生成
1、调用模型:
调用模型首先要使用torch进行加载模型,读取神经网络参数。在对原始图片提取人脸特征的基础上,加载进网络进行生成即可。因为这里我们还需要对生成的数据进行转换成图片,我们这里还需要使用numpy和opencv进行图片的转化。因为加载如模型和模型生成的必然是数据,而我们需要将生成器产生的数据再转换为图片,就用到了这两个库。
代码如下:
class Photo2cartoon:def __init__(self):
= Preprocess
= ("cuda:0" if else "cpu")
= ResnetGenerator(ngf=32, img_size=256, light=True).to()
params = ('./model;, map_location=)
.load_state_dict(params['genA2B'])
def inference(self, img):
# face alignment and segmentation
face_rgba = .process(img)
if face_rgba is None:
print('can not detect face!!!')
return None
face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
face = face_rgba[:, :, :3].copy
mask = face_rgba[:, :, 3][:, :, np.newaxis].copy / 255.
face = (face*mask + (1-mask)*255) / 127.5 - 1
face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype)
face = (face).to()
# inference
with :
cartoon = (face)[0][0]
# post-process
cartoon = np.transpose, (1, 2, 0))
cartoon = (cartoon + 1) * 127.5
cartoon = (cartoon * mask + 255 * (1 - mask)).astype)
cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
return cartoon
if __name__ == '__main__':
img = cv2.cvtColor), cv2.COLOR_BGR2RGB)
c2p = Photo2Cartoon
cartoon = c2p.inference(img)
if cartoon is not None:
cv2.imwrite, cartoon)
到这里,我们整体的程序就搭建完成,下面为我们程序的运行结果:
在这里附上源码地址:
链接:
提取码:54vp
作者简介:
李秋键,CSDN 博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap安卓武侠游戏一部,vip视频解析,文意转换工具,写作机器人等项目,发表论文若干,多次高数竞赛获奖等等。
☞全球 Python 调查报告:Python 2 正在消亡,PyCharm 比 VS Code 更受欢迎!
☞雷军喜提第四家上市公司;梨视频 App 被全网下架;Flutter 1.17 稳定版发布 | 极客头条
☞微服务太杂乱难以管理?一站式服务治理平台来袭!
☞开源一年,阿里轻量级AI推理引擎MNN 1.0.0正式发布
☞Redis 6.0 新特性:多线程连环 13 问!
☞从技术原理解析区块链为何列入新基建