Mock object cannot be cast to Entity object using ModelMapper in unit test

I am unit testing a service method, createProduct. This method receives a DTO object and then uses ModelMapper to map the DTO to a ProductEntity object. The method works fine, but in the unit test I mock modelMapper and stub it like this: when(modelMapper.map(createProductRequestDto, ProductEntity.class)).thenReturn(expectedProduct), where expectedProduct is a ProductEntity instance. Running the test throws java.lang.ClassCastException.

Stack trace:

java.lang.ClassCastException: class com.example.demo.dto.responses.product.ProductResponseDto$MockitoMock7533970 cannot be cast to class com.example.demo.entities.ProductEntity (com.example.demo.dto.responses.product.ProductResponseDto$MockitoMock7533970 and com.example.demo.entities.ProductEntity are in unnamed module of loader 'app')

    at com.example.demo.services.implementations.product.ProductCrudServiceImpl.createProduct(ProductCrudServiceImpl.java:78)
    at com.example.demo.services.product.ProductCrudServiceImplTest.createProduct_ShouldReturnProductResponseDto(ProductCrudServiceImplTest.java:136)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:688)
    at org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
    at org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
    at org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:149)
    at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestableMethod(TimeoutExtension.java:140)
    at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestMethod(TimeoutExtension.java:84)
    at org.junit.jupiter.engine.execution.ExecutableInvoker$ReflectiveInterceptorCall.lambda$ofVoidMethod$0(ExecutableInvoker.java:115)
    at org.junit.jupiter.engine.execution.ExecutableInvoker.lambda$invoke$0(ExecutableInvoker.java:105)
    at org.junit.jupiter.engine.execution.InvocationInterceptorChain$InterceptedInvocation.proceed(InvocationInterceptorChain.java:106)
    at org.junit.jupiter.engine.execution.InvocationInterceptorChain.proceed(InvocationInterceptorChain.java:64)
    at org.junit.jupiter.engine.execution.InvocationInterceptorChain.chainAndInvoke(InvocationInterceptorChain.java:45)
    at org.junit.jupiter.engine.execution.InvocationInterceptorChain.invoke(InvocationInterceptorChain.java:37)
    at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:104)
    at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:98)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod(TestMethodTestDescriptor.java:210)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:206)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:131)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:65)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:139)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:129)
    at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:127)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:126)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:84)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:143)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:129)
    at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:127)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:126)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:84)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:143)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:129)
    at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively(NodeTestTask.java:127)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:126)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:84)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:32)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:51)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:108)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:88)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.lambda$execute$0(EngineExecutionOrchestrator.java:54)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.withInterceptedStreams(EngineExecutionOrchestrator.java:67)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:52)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:96)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:75)
    at com.intellij.junit5.JUnit5IdeaTestRunner.startRunnerWithArgs(JUnit5IdeaTestRunner.java:71)
    at com.intellij.rt.junit.IdeaTestRunner$Repeater.execute(IdeaTestRunner.java:38)
    at com.intellij.rt.execution.junit.TestsRepeater.repeat(TestsRepeater.java:11)
    at com.intellij.rt.junit.IdeaTestRunner$Repeater.startRunnerWithArgs(IdeaTestRunner.java:35)
    at com.intellij.rt.junit.JUnitStarter.prepareStreamsAndStart(JUnitStarter.java:235)
    at com.intellij.rt.junit.JUnitStarter.main(JUnitStarter.java:54)

The createProduct service method:

@Override
    public ProductResponseDto createProduct(CreateProductRequestDto createProductRequestDTO) {
        Long genderId = createProductRequestDTO.getGenderId();
        Long sportId = createProductRequestDTO.getSportId();
        List<Long> categoryIds = createProductRequestDTO.getCategoryIds();
        List<Long> technologyIds = createProductRequestDTO.getTechnologyIds();

        GenderEntity genderEntity = genderCrudService.findById(genderId);
        SportEntity sportEntity = sportCrudService.findById(sportId);
        List<CategoryEntity> categoryEntities = categoryCrudService.findByIds(categoryIds);
        List<TechnologyEntity> technologyEntities = technologyService.findByIds(technologyIds);
        Set<TechnologyEntity> technologyEntitySet = technologyEntities.stream().collect(Collectors.toSet());
        Set<CategoryEntity> categoryEntitySet = categoryEntities.stream().collect(Collectors.toSet());

        ProductEntity productEntity = modelMapper.map(createProductRequestDTO, ProductEntity.class);
        productEntity.setTechnologies(technologyEntitySet);
        productEntity.setCategories(categoryEntitySet);
        productEntity.setGender(genderEntity);
        productEntity.setSport(sportEntity);

        productEntity = productRepository.save(productEntity);

        AddSizeToProductRequestDto requestDto = addSizeToProductRequestDtoFactory.createAddSizeToRequestDto(productEntity.getId(), createProductRequestDTO.getProductSizeDtoList());
        productEntity = productSizeService.addSizeToProduct(requestDto);

        return modelMapper.map(productEntity, ProductResponseDto.class);
    }

The unit test class:

public class ProductCrudServiceImplTest {
    ModelMapper modelMapper;
    ProductCrudServiceImpl productCrudServiceImpl;
    GenderCrudService genderCrudService;
    SportCrudService sportCrudService;
    CategoryCrudService categoryCrudService;
    ProductRepository productRepository;
    TechnologyService technologyService;
    ProductSizeService productSizeService;
    CreateProductRequestDto createProductRequestDTO;
    List<Long> categoryIds;
    List<Long> technologyIds;
    GenderEntity genderEntity;
    SportEntity sportEntity;
    List<CategoryEntity> categoryEntities;
    List<TechnologyEntity> technologyEntities;
    Set<TechnologyEntity> technologyEntitiesSet;
    Set<CategoryEntity> categoryEntitiesSet;
    List<ProductSizeDto> productSizeDtoList;
    ProductEntity initProduct;
    ProductEntity expectedProduct;
    AddSizeToProductRequestDto addSizeToProductRequestDto;
    AddSizeToProductRequestDtoFactory addSizeToProductRequestDtoFactory;
    ProductResponseDto resultProductDto;

    @BeforeEach
    public void beforeEach() {
        modelMapper = mock(ModelMapper.class);
        genderCrudService = mock(GenderCrudService.class);
        sportCrudService = mock(SportCrudService.class);
        categoryCrudService = mock(CategoryCrudService.class);
        productRepository = mock(ProductRepository.class);
        technologyService = mock(TechnologyService.class);
        productSizeService = mock(ProductSizeService.class);
        addSizeToProductRequestDto = mock(AddSizeToProductRequestDto.class);
        addSizeToProductRequestDtoFactory = mock(AddSizeToProductRequestDtoFactory.class);
        productCrudServiceImpl = new ProductCrudServiceImpl(
                modelMapper,
                genderCrudService,
                sportCrudService,
                categoryCrudService,
                productRepository,
                technologyService,
                productSizeService,
                addSizeToProductRequestDtoFactory
        );

        createProductRequestDTO = mock(CreateProductRequestDto.class);
        genderEntity = mock(GenderEntity.class);
        sportEntity = mock(SportEntity.class);
        categoryIds = mock(List.class);
        technologyIds = mock(List.class);
        categoryEntities = Arrays.asList(CategoryEntity.builder().id(1L).name("category").description("description").build());
        categoryEntitiesSet = categoryEntities.stream().collect(Collectors.toSet());
        technologyEntities = Arrays.asList(TechnologyEntity.builder().id(1L).name("technology").description("description").build());
        technologyEntitiesSet = technologyEntities.stream().collect(Collectors.toSet());
        expectedProduct = ProductEntity.builder()
                .id(1L)
                .name("product")
                .price(1)
                .year(2022)
                .thumbnail("thumbnail")
                .sport(sportEntity)
                .gender(genderEntity)
                .technologies(technologyEntitiesSet)
                .categories(categoryEntitiesSet)
                .description("description")
                .build();
        productSizeDtoList = mock(List.class);
        addSizeToProductRequestDto = mock(AddSizeToProductRequestDto.class);
        addSizeToProductRequestDtoFactory = mock(AddSizeToProductRequestDtoFactory.class);
        resultProductDto = mock(ProductResponseDto.class);

        when(createProductRequestDTO.getGenderId()).thenReturn(1L);
        when(createProductRequestDTO.getSportId()).thenReturn(2L);
        when(createProductRequestDTO.getCategoryIds()).thenReturn(categoryIds);
        when(createProductRequestDTO.getTechnologyIds()).thenReturn(technologyIds);

        when(genderCrudService.findById(1L)).thenReturn(genderEntity);
        when(sportCrudService.findById(1L)).thenReturn(sportEntity);
        when(categoryCrudService.findByIds(categoryIds)).thenReturn(categoryEntities);
        when(technologyService.findByIds(technologyIds)).thenReturn(technologyEntities);

        when(modelMapper.map(createProductRequestDTO, ProductEntity.class)).thenReturn(expectedProduct);
        when(productRepository.save(any())).thenReturn(expectedProduct);

        when(createProductRequestDTO.getProductSizeDtoList()).thenReturn(productSizeDtoList);
        when(addSizeToProductRequestDtoFactory.createAddSizeToRequestDto(1L, productSizeDtoList)).thenReturn(addSizeToProductRequestDto);
        when(productSizeService.addSizeToProduct(addSizeToProductRequestDto)).thenReturn(expectedProduct);
        when(modelMapper.map(any(), ArgumentMatchers.<Class<ProductResponseDto>>any())).thenReturn(resultProductDto);
    }

    @Test
    public void createProduct_ShouldReturnProductResponseDto() {
        ProductResponseDto result = productCrudServiceImpl.createProduct(createProductRequestDTO);

        ArgumentCaptor<ProductEntity> productCaptor = ArgumentCaptor.forClass(ProductEntity.class);
        verify(productRepository).save(productCaptor.capture());
        ProductEntity savedProduct = productCaptor.getValue();
        verify(modelMapper).map(productCaptor.capture(), ProductEntity.class);
        ProductEntity mappedProduct = productCaptor.getValue();
//
        assertEquals(savedProduct, expectedProduct);
        assertEquals(mappedProduct, expectedProduct);
        assertThat(result, is(resultProductDto));
    }

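I have done some research on this, but I could not pin down the root cause or figure out how it happens. Can you help me solve this problem?

Thank you very much for your time.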

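The last stub in your beforeEach method, when(modelMapper.map(any(), ArgumentMatchers.<Class<ProductResponseDto>>any())).thenReturn(resultProductDto);, seems to be what causes the problem. Because the any() matcher for the Class argument also matches ProductEntity.class, this stub also matches the call modelMapper.map(createProductRequestDTO, ProductEntity.class). When several stubbings match a call, Mockito uses the most recently defined one, so it overrides your earlier stub that returns expectedProduct.
As the stack trace shows, the object returned from the stub is of type class com.example.demo.dto.responses.product.ProductResponseDto$MockitoMock7533970. You can verify that it is the same object created by resultProductDto = mock(ProductResponseDto.class); by comparing hash codes. Since ModelMapper's map(Object, Class<D>) is generic, the cast to ProductEntity is performed in your service code at the call site, which is why the exception is thrown from ProductCrudServiceImpl.createProduct rather than from Mockito.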
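To illustrate why the exception surfaces in the service rather than inside the mock, here is a minimal, stdlib-only sketch. The map method below is a hypothetical stand-in that only copies the generic shape of ModelMapper's map(Object, Class<D>); it is not ModelMapper code, and the extra stubbedResult parameter is an assumption used to simulate what a misconfigured stub hands back.

```java
public class GenericCastDemo {
    static class ProductEntity {}
    static class ProductResponseDto {}

    // Hypothetical stand-in with the same generic shape as ModelMapper#map(Object, Class<D>).
    // It simply returns whatever object a (misconfigured) stub would hand back.
    @SuppressWarnings("unchecked")
    static <D> D map(Object source, Class<D> destinationType, Object stubbedResult) {
        // Unchecked cast: D is erased to Object, so no type check happens here
        return (D) stubbedResult;
    }

    static String tryMap() {
        try {
            // javac inserts a checkcast to ProductEntity at this call site,
            // so the failure is reported in the caller, not inside map()
            ProductEntity entity = map("createProductRequestDTO", ProductEntity.class, new ProductResponseDto());
            return "mapped to " + entity.getClass().getSimpleName();
        } catch (ClassCastException ex) {
            return "ClassCastException";
        }
    }

    public static void main(String[] args) {
        System.out.println(tryMap()); // prints "ClassCastException"
    }
}
```

The same mechanism applies to the real call on line 78 of the service: the mock returns a ProductResponseDto mock, and the cast to ProductEntity fails in createProduct.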

As a fix, to mock the behavior of modelMapper.map(productEntity, ProductResponseDto.class);, you should try using when(modelMapper.map(any(ProductEntity.class)...
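The stub-matching rule behind this fix can be sketched without Mockito. MapStub below is a hypothetical, simplified model, not Mockito's API or implementation; it assumes only Mockito's documented behavior that the most recently defined matching stub wins. The broken() configuration mirrors the question (an exact stub followed by a catch-all), while fixed() narrows the second stub so it matches only a ProductEntity source mapped to ProductResponseDto.class.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiPredicate;

public class StubOrderDemo {
    // Hypothetical mini stub registry: when several stubs match, the newest wins (as in Mockito)
    static class MapStub {
        private final List<BiPredicate<Object, Class<?>>> matchers = new ArrayList<>();
        private final List<Object> answers = new ArrayList<>();

        void when(BiPredicate<Object, Class<?>> matcher, Object answer) {
            matchers.add(matcher);
            answers.add(answer);
        }

        Object map(Object source, Class<?> type) {
            // Search newest-first
            for (int i = matchers.size() - 1; i >= 0; i--) {
                if (matchers.get(i).test(source, type)) {
                    return answers.get(i);
                }
            }
            return null; // unstubbed call: Mockito's default for object return types
        }
    }

    static class ProductEntity {}
    static class ProductResponseDto {}

    static final Object DTO = "createProductRequestDTO";
    static final ProductEntity EXPECTED_PRODUCT = new ProductEntity();
    static final ProductResponseDto RESULT_DTO = new ProductResponseDto();

    // Mirrors the question: exact stub first, then an any()/any() catch-all that shadows it
    static MapStub broken() {
        MapStub s = new MapStub();
        s.when((src, type) -> DTO.equals(src) && type == ProductEntity.class, EXPECTED_PRODUCT);
        s.when((src, type) -> true, RESULT_DTO);
        return s;
    }

    // Narrowed second stub: only ProductEntity -> ProductResponseDto.class matches it
    static MapStub fixed() {
        MapStub s = new MapStub();
        s.when((src, type) -> DTO.equals(src) && type == ProductEntity.class, EXPECTED_PRODUCT);
        s.when((src, type) -> src instanceof ProductEntity && type == ProductResponseDto.class, RESULT_DTO);
        return s;
    }

    public static void main(String[] args) {
        System.out.println(broken().map(DTO, ProductEntity.class).getClass().getSimpleName()); // prints "ProductResponseDto"
        System.out.println(fixed().map(DTO, ProductEntity.class).getClass().getSimpleName()); // prints "ProductEntity"
    }
}
```

With the narrowed matcher, both stubs can coexist in beforeEach because neither one matches the other's call.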