import sys
import os
import glob
import cv2
import numpy as np
import torch
import architecture as arch
import traceback
# Usage: python <this script> <model_path>
#   e.g. models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
if len(sys.argv) < 2:
    sys.exit('Usage: {} <model_path>  (e.g. models/RRDB_ESRGAN_x4.pth)'.format(sys.argv[0]))
model_path = sys.argv[1]
# Prefer the GPU, but fall back to the CPU instead of crashing when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Every file found (recursively) under import_folder is written to the same
# relative location under export_folder.
import_folder = 'C:/FilesToProcess/In'
export_folder = 'C:/FilesToProcess/Out'
fail_log_path = os.path.join(export_folder, '_failureLog.txt')
# Abort the whole run after this many failures in a row.
consecutive_failures_to_abort = 5

# RRDB network configuration (x4 upscaling). It must match the checkpoint:
# strict=True below rejects any missing or unexpected weights.
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                      mode='CNA', res_scale=1, upsample_mode='upconv')
# map_location='cpu' lets checkpoints saved from a GPU load on CPU-only machines;
# the model is moved to `device` afterwards.
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
model.eval()
# Inference only: no parameter needs gradients.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)
print('Model path {:s}. \nTesting...'.format(model_path))
def process(torch, model, import_folder, export_folder, image_path=None):
    """Upscale one image with *model* and write the result under *export_folder*.

    The output path keeps the image's path relative to *import_folder*, so the
    input directory tree is mirrored under *export_folder*.

    Args:
        torch: the ``torch`` module (passed in by the caller).
        model: network to run; it is called on a 1x3xHxW float tensor scaled to [0, 1].
        import_folder: root folder the input image lives under.
        export_folder: root folder the result is written under.
        image_path: file to process. When omitted, falls back to the
            module-level ``path`` variable (legacy behaviour: the driver loop
            binds ``path`` and calls this function with four arguments).

    Raises:
        ValueError: if OpenCV cannot decode the file as an image.
        OSError: if OpenCV fails to write the result.
    """
    if image_path is None:
        image_path = path  # legacy: module-level variable bound by the driver loop

    # cv2.imread returns None (rather than raising) for missing or non-image files.
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('Could not read image: {}'.format(image_path))
    img = img * 1.0 / 255
    # OpenCV loads BGR/HWC; reorder to RGB/CHW for the network.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0).to(device)  # add batch dim; `device` is module-level

    with torch.no_grad():
        # squeeze(0) drops only the batch dim, so 1-pixel-high/wide images keep their shape.
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # Back to BGR/HWC and explicit 8-bit pixels for cv2.imwrite.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    export_path = os.path.join(export_folder, os.path.relpath(image_path, import_folder))
    os.makedirs(os.path.dirname(os.path.abspath(export_path)), exist_ok=True)
    # cv2.imwrite signals write failures by returning False, not by raising.
    if not cv2.imwrite(export_path, output):
        raise OSError('Could not write image: {}'.format(export_path))
# Process every file under import_folder (recursively). Each failure is appended
# to fail_log_path and processing continues, unless
# consecutive_failures_to_abort failures happen in a row.
idx = 0
consecutive_failures = 0
for idx, path in enumerate(sorted(glob.glob(import_folder + '/**/*.*', recursive=True)), start=1):
    print(idx, path)
    try:
        # process() picks up the current file from the module-level `path`
        # bound by this for-loop.
        process(torch, model, import_folder, export_folder)
        consecutive_failures = 0
    except Exception as e:
        print('Failed to process ({0})'.format(type(e).__name__))
        # The export folder may not exist yet if nothing has been written so far.
        os.makedirs(os.path.dirname(os.path.abspath(fail_log_path)), exist_ok=True)
        # Explicit UTF-8 so non-ASCII file names cannot break the logging itself.
        with open(fail_log_path, 'a', encoding='utf-8') as fail_log:
            fail_log.write('***Failed image: ' + path + '\n')
            fail_log.write('Error: ' + str(e) + '\n')
            fail_log.write(traceback.format_exc() + '\n')
        consecutive_failures += 1
        if consecutive_failures >= consecutive_failures_to_abort:
            raise RuntimeError("Maximum consecutive failures reached, aborting processing") from e