Python 中的 SVD 图像重建
SVD image reconstruction in Python
我正在尝试对这张图片进行奇异值分解:
取前 10 个值。我有这个代码:
from PIL import Image
import numpy as np
img = Image.open('bee.jpg')
img = np.mean(img, 2)
U,s,V = np.linalg.svd(img)
recon_img = U @ s[1:10] @ V
但是当我 运行 它抛出这个错误:
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 9 is different from 819)
所以我认为我在重建时做错了什么。我不确定 np.linalg.svd(img)
创建的矩阵的维度。
我该如何解决?
对不起英语
问题是 s
的维度,如果您打印 U
、s
和 V
维度,我得到:
print(np.shape(U))
print(np.shape(s))
print(np.shape(V))
(819, 819)
(819,)
(1024, 1024)
所以U
和V
是方阵,s
是数组。您必须创建一个与图像尺寸相同的矩阵 (819 x 1024),主对角线上 s
为:
n = 10
S = np.zeros(np.shape(img))
for i in range(0, n):
S[i,i] = s[i]
print(np.shape(S))
输出:
(819, 1024)
那你就可以继续你的阐述了。为了进行比较,请查看此代码:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
img = Image.open('bee.jpg')
img = np.mean(img, 2)
U,s,V = np.linalg.svd(img)
n = 10
S = np.zeros(np.shape(img))
for i in range(0, n):
S[i,i] = s[i]
recon_img = U @ S @ V
fig, ax = plt.subplots(1, 2)
ax[0].imshow(img)
ax[0].axis('off')
ax[0].set_title('Original')
ax[1].imshow(recon_img)
ax[1].axis('off')
ax[1].set_title(f'Reconstructed n = {n}')
plt.show()
给我这个:
我正在尝试对这张图片进行奇异值分解:
取前 10 个值。我有这个代码:
from PIL import Image
import numpy as np
img = Image.open('bee.jpg')
img = np.mean(img, 2)
U,s,V = np.linalg.svd(img)
recon_img = U @ s[1:10] @ V
但是当我 运行 它抛出这个错误:
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 9 is different from 819)
所以我认为我在重建时做错了什么。我不确定 np.linalg.svd(img)
创建的矩阵的维度。
我该如何解决?
对不起英语
问题是 s
的维度,如果您打印 U
、s
和 V
维度,我得到:
print(np.shape(U))
print(np.shape(s))
print(np.shape(V))
(819, 819)
(819,)
(1024, 1024)
所以U
和V
是方阵,s
是数组。您必须创建一个与图像尺寸相同的矩阵 (819 x 1024),主对角线上 s
为:
n = 10
S = np.zeros(np.shape(img))
for i in range(0, n):
S[i,i] = s[i]
print(np.shape(S))
输出:
(819, 1024)
那你就可以继续你的阐述了。为了进行比较,请查看此代码:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
img = Image.open('bee.jpg')
img = np.mean(img, 2)
U,s,V = np.linalg.svd(img)
n = 10
S = np.zeros(np.shape(img))
for i in range(0, n):
S[i,i] = s[i]
recon_img = U @ S @ V
fig, ax = plt.subplots(1, 2)
ax[0].imshow(img)
ax[0].axis('off')
ax[0].set_title('Original')
ax[1].imshow(recon_img)
ax[1].axis('off')
ax[1].set_title(f'Reconstructed n = {n}')
plt.show()
给我这个: