将 glmnet mex 与 MATLAB 结合使用时出现分段错误

Segmentation Fault when using glmnet mex with MATLAB

从我的 MATLAB 代码调用 glmnet(从此处下载:http://web.stanford.edu/~hastie/glmnet_matlab/download.html)时,我的工作因分段错误而受到影响。我调用 glmnet 例程数千次。我注意到有关问题发生的以下特点:

  1. 输入矩阵越大,问题越频繁。
  2. 我在不同的作业中同时使用高斯分布和泊松分布,我注意到在拟合泊松分布时问题更频繁(收敛通常需要更长的时间,所以可能在内部涉及更多循环?) 由于没有关于这两个发行版的 R 版本分段错误的报告,我怀疑问题可能是内存泄漏,可能出在 mex 接口而不是我粘贴在下面的核心 glmnet Fortran 代码.非常感谢任何关于可能发生内存泄漏的见解!为冗长的代码转储道歉。

谢谢!

      subroutine mexFunction(nlhs, plhs, nrhs, prhs)
C-----------------------------------------------------------------------

      mwpointer plhs(*), prhs(*)
      mwpointer mxCreateDoubleMatrix, mxGetPr, mxCreateNumericArray
      integer nlhs, nrhs
      mwsize mxGetM, mxGetN, mxGetNzmax
      integer mxIsNumeric
      integer mxIsSparse
      
C-----------------------------------------------------------------------

C     Input
      real parm,flmin,thr, intr
      integer ka,no,ni,nr,nc,ne,nx,nlam,isd,maxit,kopt,isparse,nnz,jsd
      real, dimension (:), allocatable :: x,y,w,vp,ulam,cl,sr,xs,o,d,
     $     flog,a
      integer, dimension (:), allocatable :: ix,jx,jd,irs,jcs

      mwpointer pr

C     Output
      integer lmu,nlp,jerr
      real dev0
      real, dimension (:), allocatable :: a0,ca,alm,dev,rsq
      integer, dimension (:), allocatable :: ia,nin

C     Temporary      
      mwpointer temp_pr
      mwsize temp_m, temp_n, temp_nzmax, dims(3)
      integer task,i

C     For internal parameters
      real fdev, devmax, eps, big, pmin, prec, exmx
      integer mnlam, mxit
      
C     Check for proper number of arguments.
      if (nrhs .eq. 0) then
         task = -1;
      else
         temp_pr = mxGetPr(prhs(1))
         call getinteger(temp_pr,task,1)
      endif

C     Get input

      if (task .eq. -1) then
         call get_int_parms(fdev,eps,big,mnlam,devmax,pmin,exmx)
         call get_bnorm(prec,mxit)

         plhs(1) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(1))
         call putreal(fdev,temp_pr,1)

         plhs(2) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(2))
         call putreal(devmax,temp_pr,1)

         plhs(3) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(3))
         call putreal(eps,temp_pr,1)

         plhs(4) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(4))
         call putreal(big,temp_pr,1)

         plhs(5) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(5))
         call putinteger(mnlam,temp_pr,1)

         plhs(6) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(6))
         call putreal(pmin,temp_pr,1)

         plhs(7) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(7))
         call putreal(exmx,temp_pr,1)

         plhs(8) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(8))
         call putreal(prec,temp_pr,1)

         plhs(9) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(9))
         call putinteger(mxit,temp_pr,1)

         return   
      endif 
      
      if (task .eq. 0) then
         temp_pr = mxGetPr(prhs(2))
         call getreal(temp_pr,fdev,1)

         temp_pr = mxGetPr(prhs(3))
         call getreal(temp_pr,devmax,1)

         temp_pr = mxGetPr(prhs(4))
         call getreal(temp_pr,eps,1)

         temp_pr = mxGetPr(prhs(5))
         call getreal(temp_pr,big,1)

         temp_pr = mxGetPr(prhs(6))
         call getinteger(temp_pr,mnlam,1)

         temp_pr = mxGetPr(prhs(7))
         call getreal(temp_pr,pmin,1)

         temp_pr = mxGetPr(prhs(8))
         call getreal(temp_pr,exmx,1)

         temp_pr = mxGetPr(prhs(9))
         call getreal(temp_pr,prec,1)

         temp_pr = mxGetPr(prhs(10))
         call getinteger(temp_pr,mxit,1)

         call chg_fract_dev(fdev)
         call chg_dev_max(devmax)
         call chg_min_flmin(eps)
         call chg_big(big)
         call chg_min_lambdas(mnlam)
         call chg_min_null_prob(pmin)
         call chg_max_exp(exmx)
         call chg_bnorm(prec, mxit)

         return
      endif

c$$$  -----------------Gaussian--------------------  
c$$$  ---input---   
      
      if (task .eq. 10 .or. task .eq. 11) then
         if (task .eq. 11) then
            temp_pr = mxGetPr(prhs(3))
            temp_m = mxGetM(prhs(3))
            no = temp_m
            temp_n = mxGetN(prhs(3))
            ni = temp_n
            allocate(x(1:no*ni))
            call getreal(temp_pr,x,no*ni)
            
         else
            temp_m = mxGetM(prhs(4))
            no = temp_m

            temp_pr = mxGetPr(prhs(3))
            temp_m = mxGetM(prhs(3))
            nnz = temp_m
            allocate(xs(1:nnz))
            call getreal(temp_pr,xs,nnz)

            temp_pr = mxGetPr(prhs(19))
            allocate(irs(1:nnz))
            call getinteger(temp_pr,irs,nnz) 

            temp_pr = mxGetPr(prhs(20))
            temp_n = mxGetM(prhs(20))
            ni = temp_n - 1
            allocate(jcs(1:(ni+1)))
            call getinteger(temp_pr,jcs,(ni+1)) 
         endif

         temp_pr = mxGetPr(prhs(2))
         call getreal(temp_pr,parm,1)

         temp_pr = mxGetPr(prhs(4))
         allocate(y(1:no))
         call getreal(temp_pr,y,no)
         
         temp_pr = mxGetPr(prhs(5))
         temp_m = mxGetM(prhs(5))
         temp_n = mxGetN(prhs(5))
         allocate(jd(temp_m*temp_n))
         call getinteger(temp_pr,jd,temp_m*temp_n)     
         
         temp_pr = mxGetPr(prhs(6))
         allocate(vp(1:ni))
         call getreal(temp_pr,vp,ni)

         temp_pr = mxGetPr(prhs(7))
         call getinteger(temp_pr,ne,1)

         temp_pr = mxGetPr(prhs(8))
         call getinteger(temp_pr,nx,1)
         
         temp_pr = mxGetPr(prhs(9))
         call getinteger(temp_pr,nlam,1)

         temp_pr = mxGetPr(prhs(10))
         call getreal(temp_pr,flmin,1)     
         
         temp_pr = mxGetPr(prhs(11))
         temp_m = mxGetM(prhs(11))
         temp_n = mxGetN(prhs(11))
         allocate(ulam(1:temp_m * temp_n))
         call getreal(temp_pr,ulam,temp_m * temp_n)
         
         temp_pr = mxGetPr(prhs(12))
         call getreal(temp_pr,thr,1)
         
         temp_pr = mxGetPr(prhs(13))
         call getinteger(temp_pr,isd,1)

         temp_pr = mxGetPr(prhs(14))
         allocate(w(1:no))
         call getreal(temp_pr,w,no)

         temp_pr = mxGetPr(prhs(15))
         call getinteger(temp_pr,ka,1)

         temp_pr = mxGetPr(prhs(16))
         allocate(cl(1:2*ni))
         call getreal(temp_pr,cl,2*ni)

         temp_pr = mxGetPr(prhs(17))
         call getinteger(temp_pr,intr,1)

         temp_pr = mxGetPr(prhs(18))
         call getinteger(temp_pr,maxit,1)       

c$$$  ---prepare output---

         allocate(ia(1:nx))
         call zerointeger(ia,nx)
         allocate(nin(1:nlam))
         call zerointeger(nin,nlam)
         allocate(alm(1:nlam))
         call zeroreal(alm,nlam)
         allocate(a0(1:nlam))
         call zeroreal(a0,nlam)         
         allocate(ca(1:nx*nlam))
         call zeroreal(ca,nx*nlam)
         allocate(rsq(1:nlam))
         call zeroreal(rsq,nlam)


c$$$  ---computation----

         if (task .eq. 11) then    
            call elnet(ka,parm,no,ni,x,y,w,jd,vp,cl,ne,nx,nlam,flmin,
     $           ulam,thr,isd,intr,maxit,lmu,a0,ca,ia,nin,rsq,alm,
     $           nlp,jerr)
         else
            call spelnet(ka,parm,no,ni,xs,jcs,irs,y,w,jd,vp,cl,ne,nx,
     $           nlam,flmin,ulam,thr,isd,intr,maxit,lmu,a0,ca,ia,nin,
     $           rsq,alm,nlp,jerr)
         endif

c$$$  ----output-----

         plhs(1) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(1))
         call putinteger(lmu,temp_pr,1)

         plhs(4) = mxCreateDoubleMatrix(nx,1,0)
         temp_pr = mxGetPr(plhs(4))
         call putinteger(ia,temp_pr,nx)
         
         plhs(5) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(5))
         call putinteger(nin,temp_pr,lmu)
         
         plhs(7) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(7))
         call putreal(alm,temp_pr,lmu)
         
         plhs(8) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(8))
         call putinteger(nlp,temp_pr,1)
         
         plhs(9) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(9))
         call putinteger(jerr,temp_pr,1)

         plhs(2) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(2))
         call putreal(a0,temp_pr,lmu)   
         
         plhs(3) = mxCreateDoubleMatrix(nx,lmu,0)
         temp_pr = mxGetPr(plhs(3))
         call putreal(ca,temp_pr,nx*lmu)  
         
         plhs(6) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(6))
         call putreal(rsq,temp_pr,lmu)  

         deallocate(y)
         deallocate(jd)
         deallocate(vp)
         deallocate(ulam)
         deallocate(a0)
         deallocate(ca)
         deallocate(ia)
         deallocate(nin)
         deallocate(alm)
         deallocate(w)
         deallocate(rsq)
         deallocate(cl)
         
         if (task .eq. 11) then
            deallocate(x)
         else
            deallocate(xs)
            deallocate(irs)
            deallocate(jcs)  
         endif
         return
      endif      
c$$$  --------------end of Gaussian---------------------------
c$$$  ---------------Poisson--------------------------
c$$$  ---input---   
      
      if (task .eq. 50 .or. task .eq. 51) then
         if (task .eq. 51) then
            temp_pr = mxGetPr(prhs(3))
            temp_m = mxGetM(prhs(3))
            no = temp_m
            temp_n = mxGetN(prhs(3))
            ni = temp_n
            allocate(x(1:no*ni))
            call getreal(temp_pr,x,no*ni)
            
         else
            temp_m = mxGetM(prhs(4))
            no = temp_m

            temp_pr = mxGetPr(prhs(3))
            temp_m = mxGetM(prhs(3))
            nnz = temp_m
            allocate(xs(1:nnz))
            call getreal(temp_pr,xs,nnz)

            temp_pr = mxGetPr(prhs(19))
            allocate(irs(1:nnz))
            call getinteger(temp_pr,irs,nnz) 

            temp_pr = mxGetPr(prhs(20))
            temp_n = mxGetM(prhs(20))
            ni = temp_n - 1
            allocate(jcs(1:(ni+1)))
            call getinteger(temp_pr,jcs,(ni+1)) 
         endif

         temp_pr = mxGetPr(prhs(2))
         call getreal(temp_pr,parm,1)

         temp_pr = mxGetPr(prhs(4))
         allocate(y(1:no))
         call getreal(temp_pr,y,no)
         
         temp_pr = mxGetPr(prhs(5))
         temp_m = mxGetM(prhs(5))
         temp_n = mxGetN(prhs(5))
         allocate(jd(temp_m*temp_n))
         call getinteger(temp_pr,jd,temp_m*temp_n)     
         
         temp_pr = mxGetPr(prhs(6))
         allocate(vp(1:ni))
         call getreal(temp_pr,vp,ni)

         temp_pr = mxGetPr(prhs(7))
         call getinteger(temp_pr,ne,1)

         temp_pr = mxGetPr(prhs(8))
         call getinteger(temp_pr,nx,1)
         
         temp_pr = mxGetPr(prhs(9))
         call getinteger(temp_pr,nlam,1)

         temp_pr = mxGetPr(prhs(10))
         call getreal(temp_pr,flmin,1)     
         
         temp_pr = mxGetPr(prhs(11))
         temp_m = mxGetM(prhs(11))
         temp_n = mxGetN(prhs(11))
         allocate(ulam(1:temp_m * temp_n))
         call getreal(temp_pr,ulam,temp_m * temp_n)
         
         temp_pr = mxGetPr(prhs(12))
         call getreal(temp_pr,thr,1)
         
         temp_pr = mxGetPr(prhs(13))
         call getinteger(temp_pr,isd,1)

         temp_pr = mxGetPr(prhs(14))
         allocate(w(1:no))
         call getreal(temp_pr,w,no)

         temp_pr = mxGetPr(prhs(15))
         allocate(cl(1:2*ni))
         call getreal(temp_pr,cl,2*ni)

         temp_pr = mxGetPr(prhs(16))
         call getinteger(temp_pr,intr,1)

         temp_pr = mxGetPr(prhs(17))
         call getinteger(temp_pr,maxit,1)

         temp_pr = mxGetPr(prhs(18))
         allocate(o(1:no))
         call getreal(temp_pr,o,no)

c$$$  ---prepare output---

         allocate(ia(1:nx))
         call zerointeger(ia,nx)
         allocate(nin(1:nlam))
         call zerointeger(nin,nlam)
         allocate(alm(1:nlam))
         call zeroreal(alm,nlam)
         allocate(a0(1:nlam))
         call zeroreal(a0,nlam)
         allocate(ca(1:nx*nlam))
         call zeroreal(ca,nx*nlam)
         allocate(dev(1:nlam))
         call zeroreal(dev,nlam)

c$$$  ---computation----

         if (task .eq. 51) then    
            call fishnet(parm,no,ni,x,y,o,w,jd,vp,cl,ne,nx,nlam,flmin,
     $           ulam,thr,isd,intr,maxit,lmu,a0,ca,ia,nin,dev0,dev,alm,
     $           nlp,jerr)
         else
            call spfishnet(parm,no,ni,xs,jcs,irs,y,o,w,jd,vp,cl,ne,nx,
     $           nlam,flmin,ulam,thr,isd,intr,maxit,lmu,a0,ca,ia,
     $           nin,dev0,dev,alm,nlp,jerr)
         endif

c$$$  ----output-----

         plhs(1) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(1))
         call putinteger(lmu,temp_pr,1)

         plhs(4) = mxCreateDoubleMatrix(nx,1,0)
         temp_pr = mxGetPr(plhs(4))
         call putinteger(ia,temp_pr,nx)
         
         plhs(5) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(5))
         call putinteger(nin,temp_pr,lmu)
         
         plhs(7) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(7))
         call putreal(alm,temp_pr,lmu)
         
         plhs(8) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(8))
         call putinteger(nlp,temp_pr,1)
         
         plhs(9) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(9))
         call putinteger(jerr,temp_pr,1)

         plhs(2) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(2))
         call putreal(a0,temp_pr,lmu)

         plhs(3) = mxCreateDoubleMatrix(nx,lmu,0)
         temp_pr = mxGetPr(plhs(3))
         call putreal(ca,temp_pr,nx*lmu)
         
         plhs(6) = mxCreateDoubleMatrix(lmu,1,0)
         temp_pr = mxGetPr(plhs(6))
         call putreal(dev,temp_pr,lmu)

         plhs(10) = mxCreateDoubleMatrix(1,1,0)
         temp_pr = mxGetPr(plhs(10))
         call putreal(dev0,temp_pr,1)

         plhs(11) = mxCreateDoubleMatrix(no,1,0)
         temp_pr = mxGetPr(plhs(11))
         call putreal(o,temp_pr,no)

         deallocate(y)
         deallocate(jd)
         deallocate(vp)
         deallocate(ulam)
         deallocate(a0)
         deallocate(ca)
         deallocate(ia)
         deallocate(nin)
         deallocate(alm)
         deallocate(cl)
         deallocate(o)
         deallocate(dev)
         
         if (task .eq. 51) then
            deallocate(x)
         else
            deallocate(xs)
            deallocate(irs)
            deallocate(jcs)  
         endif
         return
      endif

c$$$  --------------------end of Poisson------------------

      return
      end

C     End of subroutine mexFunction
      
      subroutine real8toreal(x, y, size)
      integer size
      real*8 x(size)
      real y(size)
      do 10 i=1,size
         y(i)= x(i)
 10   continue
      return
      end

      subroutine realtoreal8(x, y, size)
      integer size
      real x(size)
      real*8 y(size)
      do 20 i=1,size
         y(i)= x(i)
 20   continue
      return
      end
      
      subroutine real8tointeger(x, y, size)
      integer size
      real*8 x(size)
      integer y(size)
      do 30 i=1,size
         y(i)= x(i)
 30   continue
      return
      end
      
      subroutine integertoreal8(x, y, size)
      integer size
      integer x(size)
      real*8 y(size)
      do 40 i=1,size
         y(i)= x(i)
 40   continue
      return
      end
      
      subroutine getreal(pr,x,size)
      mwpointer pr
      integer size
      real x(size)
      real*8, dimension (:), allocatable :: temp
      allocate(temp(1:size))
      call mxCopyPtrToReal8(pr,temp,size)
      call real8toreal(temp,x,size)
      deallocate(temp)      
      return
      end
      
      subroutine getinteger(pr,x,size)
      mwpointer pr
      integer size
      integer x(size)
      real*8, dimension (:), allocatable :: temp
      allocate(temp(1:size))
      call mxCopyPtrToReal8(pr,temp,size)
      call real8tointeger(temp,x,size)
      deallocate(temp)      
      return
      end      
      
      subroutine putreal(x,pr,size)
      mwpointer pr
      integer size
      real x(size)
      real*8, dimension (:), allocatable :: temp
      allocate(temp(1:size))
      call realtoreal8(x,temp,size)
      call mxCopyReal8ToPtr(temp,pr,size)
      deallocate(temp)      
      return
      end
      
      subroutine putinteger(x,pr,size)
      mwpointer pr
      integer size
      integer x(size)
      real*8, dimension (:), allocatable :: temp
      allocate(temp(1:size))
      call integertoreal8(x,temp,size)
      call mxCopyReal8ToPtr(temp,pr,size)
      deallocate(temp)      
      return
      end            
      
      subroutine zeroreal(x,size)
      integer size
      real x(size)
      do 90 i=1,size
         x(i) = 0
 90   continue     
      return 
      end
      
      subroutine zerointeger(x,size)
      integer size
      integer x(size)
      do 100 i=1,size
         x(i) = 0
 100  continue
      return
      end

我要做的第一件事是清理 MATLAB API 界面的东西。请记住,在 Fortran 中,您不会像在 C/C++ 中那样在 function/subroutine 参数列表中获得自动类型提升。因此,获得准确的签名很重要。您永远不应将文字整数传递给 MATLAB API 函数。您应该传递完全按照 API 指定键入的变量,以确保没有不匹配。例如,采用此代码:

  subroutine getreal(pr,x,size)
  mwpointer pr
  integer size
  real x(size)
  real*8, dimension (:), allocatable :: temp
  allocate(temp(1:size))
  call mxCopyPtrToReal8(pr,temp,size)
  call real8toreal(temp,x,size)
  deallocate(temp)      
  return
  end

API 中 mxCopyPtrToReal8 的签名是:

  subroutine mxCopyPtrToReal8(px, y, n)
  mwPointer px
  real*8 y(n)
  mwSize n

因此您可能存在不匹配,因为默认的 Fortran 整数可能与 mwSize 不匹配。此外,size 是 Fortran 内部函数的名称,因此为您的变量取一个不同的名称可能更合适。
我会将该子例程更改为:

  subroutine getreal(pr,x,sizex)
  mwpointer pr
  mwSize sizex
  real x(sizex)
  real*8, dimension (:), allocatable :: temp
  allocate(temp(1:sizex))
  call mxCopyPtrToReal8(pr,temp,sizex)
  call real8toreal(temp,x,sizex)
  deallocate(temp)      
  return
  end

现在可以确保 sizex 是合适的类型。您还需要更改调用例程中变量的类型。

(旁注:实际上,我不会做任何你正在做的事情......我只是写一个循环来将值直接从 mxArray 复制到你的真实数组中,而不需要额外的副本和内存 allocation/deallocation)

另一个例子是这样的:

  integer ...,nx,...
      :
  integer lmu,...
       :
     plhs(3) = mxCreateDoubleMatrix(nx,lmu,0)

应替换为:

  mwSize nx, lmu
  integer*4 :: ComplexFlag = 0
       :
     plhs(3) = mxCreateDoubleMatrix(nx,lmu,ComplexFlag)

而且,坦率地说,您有很多可以用简单语句替换的赋值循环。例如,

  call real8toreal(temp,x,sizex)

可以替换为:

  x = temp

还有这个:

 allocate(ia(1:nx))
 call zerointeger(ia,nx)
 allocate(nin(1:nlam))
 call zerointeger(nin,nlam)
 allocate(alm(1:nlam))
 call zeroreal(alm,nlam)
 allocate(a0(1:nlam))
 call zeroreal(a0,nlam)
 allocate(ca(1:nx*nlam))
 call zeroreal(ca,nx*nlam)
 allocate(dev(1:nlam))
 call zeroreal(dev,nlam)

可以这样替换:

 allocate(ia(1:nx))
 ia = 0
 allocate(nin(1:nlam))
 nin = 0
 allocate(alm(1:nlam))
 alm = 0.0
 allocate(a0(1:nlam))
 a0 = 0.0
 allocate(ca(1:nx*nlam))
 ca = 0.0
 allocate(dev(1:nlam))
 dev = 0.0

等等