Simplify the code to find the last occurrence of a value

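I have a 4D array [time, model, longitude, latitude] containing the values 0 and 1. I want to find the position of the last zero along the time axis, i.e. the last year in which the value is zero. I want to do this for the whole time series at every [longitude, latitude, model] point and return a 3D array.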

My code takes a long time to run. Is there another way to do this?

element=0
for k in range (36): #model num
  for j in range (31): #latitude
    for i in range (180): # longitude
      if t_test_1v1[169,k,j,i]==0:
        ET[k,j,i]=0
        continue
      elif np.any(t_test_1v1[:,k,j,i]==1):
        ET_value=max([count for count, item in enumerate(t_test_1v1[1:169,k,j,i]) if item == element], default=0)
        ET[k,j,i]=ET_value+1921
        continue
      else:
        ET[k,j,i]=1920

Here is a sample of my input file:

array([[[[0, 0, 1, ..., 1, 1, 1],
         [0, 1, 1, ..., 0, 0, 0],
         [1, 1, 0, ..., 0, 0, 1],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],

        [[1, 1, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 1, 1, 1],
Coordinates:(time: 240, deptht: 36, latitude: 31, longitude: 180)>
 * Time   (end_year) datetime64[ns] 1921-12-31 1922-12-31 ... 2100-12-31
 * deptht     (deptht) int64 1 2 3 4 5 6 7 8 9 ... 28 29 30 31 32 33 34 35 36
  * longitude  (longitude) float64 30.0 32.0 34.0 36.0 ... 384.0 386.0 388.0
  * latitude   (latitude) float64 -36.0 -34.0 -32.0 -30.0 ... 32.0 34.0 36.0

The output file looks like this:

<xarray.DataArray (deptht:36, latitude: 37, longitude: 180)>
array([[1983., 2011., 2022., ..., 1937., 1937., 1962.],
       [2048., 2081., 2083., ...,    1920.,    0., 2011.],
       [2044., 1920., 1993., ...,    0.,    0.,    1920.],
       ...,
       [2004., 1993., 1993., ...,    0., 2010., 2011.],
       [1920., 1998., 1988., ..., 2011., 2014., 2014.],
       [2000.,    0.,    0., ..., 2014., 2011., 2000.]])
Coordinates:
 * deptht     (deptht) int64 1 2 3 4 5 6 7 8 9 ... 28 29 30 31 32 33 34 35 36
  * longitude  (longitude) float64 30.0 32.0 34.0 36.0 ... 384.0 386.0 388.0
  * latitude   (latitude) float64 -36.0 -34.0 -32.0 -30.0 ... 32.0 34.0 36.0

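The code below does the following:

  1. If the series contains only zeros, it returns 0.
  2. If the series contains only ones, it returns 1920.
  3. If the series contains both zeros and ones, it finds the position of the last 0.
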
import numpy as np
import xarray as xr
import pandas as pd

# Generate 4D array to test
time_length = 240
depth_length = 36
longitude_length = 37
latitude_length = 180
nums = np.ones(time_length * depth_length * longitude_length * latitude_length)
nums[:175400] = 0
np.random.shuffle(nums)
nums = nums.reshape((time_length, depth_length, longitude_length, latitude_length))
times = pd.date_range("1921-01-01", periods=time_length, freq='y')
depth = np.arange(0, depth_length, 1)
longitude = np.random.random(longitude_length)
latitude = np.random.random(latitude_length)

foo = xr.DataArray(nums, coords=[times, depth, longitude, latitude], dims=["Time", "depth", "longitude", "latitude"])
time = xr.DataArray(np.arange(1921, 1921 + time_length, 1), coords=[times], dims="Time")
# print(time)
# The next step finds, for each cell, the position of the last zero along axis 0 (time).
# The (foo == 0) mask is multiplied by a strictly increasing array
# (np.arange over all elements, reshaped to the 4D shape) so that the last zero
# holds the largest value and argmax returns the last match instead of the first.
ids = ((foo == 0) * np.arange(0, time_length * depth_length * longitude_length * latitude_length, 1).reshape(
    time_length, depth_length, longitude_length, latitude_length)).argmax(axis=0)
results = xr.DataArray(
    np.zeros(depth_length * longitude_length * latitude_length).reshape(depth_length, longitude_length,
                                                                                      latitude_length),
    coords=[depth, longitude, latitude], dims=[ "depth", "longitude", "latitude"])

for id, year in zip(np.arange(1, 241, 1), np.arange(1921, 1921 + time_length, 1)):
    results = results + ((ids == id) * year)
print(results)

# now the cases where it's all 0 or all 1

total = foo.sum(axis=0)
# Cells whose series is all ones sum to time_length; set them to 1920.
results = results.where(total != time_length, 1920)
# Cells whose series is all zeros sum to 0; set them to 0.
results = results.where(total != 0, 0)
print(results)
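
For comparison, the same three rules can also be written without any Python-level loop. The following is a minimal sketch in plain NumPy, not the method used above: it reverses the 0/1 mask along the time axis so that argmax lands on the last zero. The function name last_zero_year and its parameters are hypothetical. It assumes that time is axis 0, that the data holds only 0 and 1, and that time index i corresponds to year first_year + i (adjust first_year if your indexing convention differs).

```python
import numpy as np


def last_zero_year(data, first_year=1921, all_ones_value=1920, all_zeros_value=0):
    """Year of the last zero along axis 0 (hypothetical helper, see assumptions above)."""
    is_zero = (data == 0)
    n_time = data.shape[0]

    # argmax returns the FIRST True, so search the time-reversed mask
    # and convert the position back to an index in the original order.
    last_idx = n_time - 1 - is_zero[::-1].argmax(axis=0)

    # Assumption: time index i maps to year first_year + i.
    years = first_year + last_idx

    has_zero = is_zero.any(axis=0)
    has_one = (~is_zero).any(axis=0)  # assumes the data holds only 0 and 1

    # Series with no zero at all (only ones) get all_ones_value.
    out = np.where(has_zero, years, all_ones_value)
    # Series with no one at all (only zeros) get all_zeros_value.
    out = np.where(has_one, out, all_zeros_value)
    return out


# Tiny demo with shape (time=4, 1, 1, 3): a mixed series, all ones, all zeros.
demo = np.array([[0, 1, 0],
                 [1, 1, 0],
                 [0, 1, 0],
                 [1, 1, 0]]).reshape(4, 1, 1, 3)
print(last_zero_year(demo).ravel().tolist())  # [1923, 1920, 0]
```

With an xarray DataArray, passing foo.values and wrapping the returned array in a new DataArray with the depth, longitude and latitude coordinates should give the same layout as results above.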