使用链式法则求导在 MATLAB 中不起作用
Derivate using chain rule doesn't work in MATLAB
我正在尝试推导给定函数的梯度和粗麻布矩阵。当我直接做渐变时效果很好但是当我应用链式规则时它不起作用并抛出如下错误
Error using sym/diff (line 70)
Second argument must be a variable or a nonnegative integer specifying the number of differentiations.
Error in EO_a1 (line 12)
dfr = diff(f(x),r(x));
我的 MATLAB 代码
syms x a b const r(x)
const = (a*x);
r(x) = (const - b);
f(x) = (1/2)*(r(x)^2);
gradient = diff(f(x));
gradient;
hessian = diff(gradient);
hessian;
%gradient applying the chain rule
dfr = diff(f(x),r(x));
dfr;
drx = diff(dfr,x);
drx;
可能不理想,但您可以使用函数 functionalDerivative()
对函数 r(x)
求导。请注意,r(x)
必须在评估链式法则的第一部分后声明。之后,链式法则的第二部分 r'(x)
可以在使用 subs()
将符号 representation/equation 替换为 r(x)
后进行计算。在此之后,链式法则的两个部分可以相乘。在此过程之后,结果通过采用 diff(f(x))
与第一个解决方案中的步骤解决方案相匹配,该 diff(f(x))
与 x
不同,为了简洁起见,也可以通过 diff(f(x),x)
完成。
%METHOD 2: CHAIN RULE%
clear;
syms x a b const r(x)
const = (a*x);
f(x) = (1/2)*(r(x)^2);
%First part of chain rule%
dfr = functionalDerivative(f,r(x));
dfr = subs(dfr,r(x),(const - b));
%Second part of chain rule%
drx = diff(dfr,x);
%Product of parts of the chain rule%
Chain_Rule_Result = dfr*drx;
Chain_Rule_Result
运行 使用 MATLAB R2019b
我正在尝试推导给定函数的梯度和粗麻布矩阵。当我直接做渐变时效果很好但是当我应用链式规则时它不起作用并抛出如下错误
Error using sym/diff (line 70)
Second argument must be a variable or a nonnegative integer specifying the number of differentiations.
Error in EO_a1 (line 12)
dfr = diff(f(x),r(x));
我的 MATLAB 代码
syms x a b const r(x)
const = (a*x);
r(x) = (const - b);
f(x) = (1/2)*(r(x)^2);
gradient = diff(f(x));
gradient;
hessian = diff(gradient);
hessian;
%gradient applying the chain rule
dfr = diff(f(x),r(x));
dfr;
drx = diff(dfr,x);
drx;
可能不理想,但您可以使用函数 functionalDerivative()
对函数 r(x)
求导。请注意,r(x)
必须在评估链式法则的第一部分后声明。之后,链式法则的第二部分 r'(x)
可以在使用 subs()
将符号 representation/equation 替换为 r(x)
后进行计算。在此之后,链式法则的两个部分可以相乘。在此过程之后,结果通过采用 diff(f(x))
与第一个解决方案中的步骤解决方案相匹配,该 diff(f(x))
与 x
不同,为了简洁起见,也可以通过 diff(f(x),x)
完成。
%METHOD 2: CHAIN RULE%
clear;
syms x a b const r(x)
const = (a*x);
f(x) = (1/2)*(r(x)^2);
%First part of chain rule%
dfr = functionalDerivative(f,r(x));
dfr = subs(dfr,r(x),(const - b));
%Second part of chain rule%
drx = diff(dfr,x);
%Product of parts of the chain rule%
Chain_Rule_Result = dfr*drx;
Chain_Rule_Result
运行 使用 MATLAB R2019b