如何使用 groupBy 创建一个值为 BigDecimal 字段平均值的地图?
How can I use groupBy to create a map whose values are averages of BigDecimal fields?
我有以下 class:
final public class Person {
private final String name;
private final String state;
private final BigDecimal salary;
public Person(String name, String state, BigDecimal salary) {
this.name = name;
this.state = state;
this.salary = salary;
}
//getters omitted for brevity...
}
我想创建一个地图,列出各州的平均工资。我如何使用 Java8 流来做到这一点?我尝试在 groupBy 上使用下游收集器,但无法以优雅的方式使用。
我做了以下工作,但看起来很丑陋:
Stream.of(p1,p2,p3,p4).collect(groupingBy(Person::getState, mapping(d -> d.getSalary(), toList())))
.forEach((state,wageList) -> {
System.out.print(state+"-> ");
final BigDecimal[] wagesArray = wageList.stream()
.map(bd -> new BigDecimal[]{bd, BigDecimal.ONE})
.reduce((a, b) -> new BigDecimal[]{a[0].add(b[0]), a[1].add(BigDecimal.ONE)})
.get();
System.out.println(wagesArray[0].divide(wagesArray[1])
.setScale(2, RoundingMode.CEILING));
});
有没有更好的方法?
这是一种方法。虽然我仍然相信可能还有其他更好的方法。
Map<String, Double> collect = Arrays.asList(p, p1, p2, p3, p4)
.stream()
.collect(groupingBy(k -> k.getState(), mapping(v -> v.salary, toList())))
.entrySet()
.stream()
.collect(toMap(k -> k.getKey(), v -> v.getValue().stream().mapToDouble(BigDecimal::doubleValue).average().getAsDouble()));
这是一个仅使用 BigDecimal 算法的完整示例,并展示了如何实现自定义收集器
import java.math.BigDecimal;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public final class Person {
private final String name;
private final String state;
private final BigDecimal salary;
public Person(String name, String state, BigDecimal salary) {
this.name = name;
this.state = state;
this.salary = salary;
}
public String getName() {
return name;
}
public String getState() {
return state;
}
public BigDecimal getSalary() {
return salary;
}
public static void main(String[] args) {
Person p1 = new Person("John", "NY", new BigDecimal("2000"));
Person p2 = new Person("Jack", "NY", new BigDecimal("3000"));
Person p3 = new Person("Jane", "GA", new BigDecimal("1500"));
Person p4 = new Person("Jackie", "GA", new BigDecimal("2500"));
Map<String, BigDecimal> result =
Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
new AveragingCollector())));
System.out.println("result = " + result);
}
private static class AveragingCollector implements Collector<BigDecimal, IntermediateResult, BigDecimal> {
@Override
public Supplier<IntermediateResult> supplier() {
return IntermediateResult::new;
}
@Override
public BiConsumer<IntermediateResult, BigDecimal> accumulator() {
return IntermediateResult::add;
}
@Override
public BinaryOperator<IntermediateResult> combiner() {
return IntermediateResult::combine;
}
@Override
public Function<IntermediateResult, BigDecimal> finisher() {
return IntermediateResult::finish
}
@Override
public Set<Characteristics> characteristics() {
return Collections.emptySet();
}
}
private static class IntermediateResult {
private int count = 0;
private BigDecimal sum = BigDecimal.ZERO;
IntermediateResult() {
}
void add(BigDecimal value) {
this.sum = this.sum.add(value);
this.count++;
}
IntermediateResult combine(IntermediateResult r) {
this.sum = this.sum.add(r.sum);
this.count += r.count;
return this;
}
BigDecimal finish() {
return sum.divide(BigDecimal.valueOf(count), 2, BigDecimal.ROUND_HALF_UP);
}
}
}
如果您接受将 BigDecimal 值转换为双倍(对于平均工资,恕我直言,这是完全可以接受的),您可以只使用
Map<String, Double> result2 =
Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collectors.averagingDouble(BigDecimal::doubleValue))));
注意:这不是最佳解决方案。阅读下文(从编辑开始)以获得更好的方法。
您可以使用 Collector.of
将每个州的工资累加到 ArrayList
中,然后使用 finisher 函数计算每个州的平均值:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collector.<BigDecimal, List<BigDecimal>, BigDecimal>of(
ArrayList::new, // create accumulator
List::add, // add to accumulator
(l1, l2) -> { // combine two partial accumulators
l1.addAll(l2);
return l1;
},
l -> l.stream() // finish with a reduction that returns average
.reduce(BigDecimal.ZERO, BigDecimal::add)
.divide(BigDecimal.valueOf(l.size()))))));
这种方法不会丢失精度,因为每个计算每个状态平均值的操作都是在 BigDecimal
个实例上执行的。
编辑: 正如@Holger 在他的评论中指出的那样,这不是解决问题的最佳方法,因为使用 ArrayList
来存储所有 BigDecimal
实例根本不需要计算平均值。相反,使用 BigDecimal
来累加部分和并使用 long
来累加计数就足够了(这是@JB Nizet 在他的回答中遵循的方法,这里我修改了一些小细节).
这是考虑到这些因素的修改版本:
private static class Acc {
BigDecimal sum = BigDecimal.ZERO;
long count = 0;
void add(BigDecimal v) {
sum = sum.add(v);
count++;
}
Acc merge(Acc acc) {
sum = sum.add(acc.sum);
count += acc.count;
return this;
}
BigDecimal avg() {
return sum.divide(BigDecimal.valueOf(count));
}
}
Acc
是用来累加部分结果的class(包括部分和和计数)。
现在,我们可以使用 Collector.of
和 class:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg))));
或者更好的是,我们可以声明一个辅助方法,其中 Acc
class 是本地 class:
public static Collector<BigDecimal, ?, BigDecimal> averagingBigDecimal() {
class Acc { // local class, lives only inside this method :P
BigDecimal sum = BigDecimal.ZERO;
long count = 0;
void add(BigDecimal value) {
sum = sum.add(value);
count++;
}
Acc merge(Acc acc) {
sum = sum.add(acc.sum);
count += acc.count;
return this;
}
BigDecimal avg() {
return sum.divide(BigDecimal.valueOf(count));
}
}
return Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg);
}
此方法现在可以如下使用:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary, averagingBigDecimal())));
如果您需要 BigDecimal
精度,并且不介意额外的迭代,您可以这样做:
static Map<String, BigDecimal> averageByState(List<Person> persons) {
// collect sums
Map<String, BigDecimal> sumByState = persons.stream()
.collect(groupingBy(
Person::getState,
HashMap::new,
mapping(Person::getSalary, reducing(BigDecimal.ZERO, BigDecimal::add))));
// collect counts
Map<String, Long> countByState = persons.stream()
.collect(groupingBy(Person::getState, counting()));
// merge
sumByState.replaceAll((state, sum) -> sum.divide(BigDecimal.valueOf(countByState.get(state))));
return sumByState;
}
我有以下 class:
final public class Person {
private final String name;
private final String state;
private final BigDecimal salary;
public Person(String name, String state, BigDecimal salary) {
this.name = name;
this.state = state;
this.salary = salary;
}
//getters omitted for brevity...
}
我想创建一个地图,列出各州的平均工资。我如何使用 Java8 流来做到这一点?我尝试在 groupBy 上使用下游收集器,但无法以优雅的方式使用。
我做了以下工作,但看起来很丑陋:
Stream.of(p1,p2,p3,p4).collect(groupingBy(Person::getState, mapping(d -> d.getSalary(), toList())))
.forEach((state,wageList) -> {
System.out.print(state+"-> ");
final BigDecimal[] wagesArray = wageList.stream()
.map(bd -> new BigDecimal[]{bd, BigDecimal.ONE})
.reduce((a, b) -> new BigDecimal[]{a[0].add(b[0]), a[1].add(BigDecimal.ONE)})
.get();
System.out.println(wagesArray[0].divide(wagesArray[1])
.setScale(2, RoundingMode.CEILING));
});
有没有更好的方法?
这是一种方法。虽然我仍然相信可能还有其他更好的方法。
Map<String, Double> collect = Arrays.asList(p, p1, p2, p3, p4)
.stream()
.collect(groupingBy(k -> k.getState(), mapping(v -> v.salary, toList())))
.entrySet()
.stream()
.collect(toMap(k -> k.getKey(), v -> v.getValue().stream().mapToDouble(BigDecimal::doubleValue).average().getAsDouble()));
这是一个仅使用 BigDecimal 算法的完整示例,并展示了如何实现自定义收集器
import java.math.BigDecimal;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public final class Person {
private final String name;
private final String state;
private final BigDecimal salary;
public Person(String name, String state, BigDecimal salary) {
this.name = name;
this.state = state;
this.salary = salary;
}
public String getName() {
return name;
}
public String getState() {
return state;
}
public BigDecimal getSalary() {
return salary;
}
public static void main(String[] args) {
Person p1 = new Person("John", "NY", new BigDecimal("2000"));
Person p2 = new Person("Jack", "NY", new BigDecimal("3000"));
Person p3 = new Person("Jane", "GA", new BigDecimal("1500"));
Person p4 = new Person("Jackie", "GA", new BigDecimal("2500"));
Map<String, BigDecimal> result =
Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
new AveragingCollector())));
System.out.println("result = " + result);
}
private static class AveragingCollector implements Collector<BigDecimal, IntermediateResult, BigDecimal> {
@Override
public Supplier<IntermediateResult> supplier() {
return IntermediateResult::new;
}
@Override
public BiConsumer<IntermediateResult, BigDecimal> accumulator() {
return IntermediateResult::add;
}
@Override
public BinaryOperator<IntermediateResult> combiner() {
return IntermediateResult::combine;
}
@Override
public Function<IntermediateResult, BigDecimal> finisher() {
return IntermediateResult::finish
}
@Override
public Set<Characteristics> characteristics() {
return Collections.emptySet();
}
}
private static class IntermediateResult {
private int count = 0;
private BigDecimal sum = BigDecimal.ZERO;
IntermediateResult() {
}
void add(BigDecimal value) {
this.sum = this.sum.add(value);
this.count++;
}
IntermediateResult combine(IntermediateResult r) {
this.sum = this.sum.add(r.sum);
this.count += r.count;
return this;
}
BigDecimal finish() {
return sum.divide(BigDecimal.valueOf(count), 2, BigDecimal.ROUND_HALF_UP);
}
}
}
如果您接受将 BigDecimal 值转换为双倍(对于平均工资,恕我直言,这是完全可以接受的),您可以只使用
Map<String, Double> result2 =
Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collectors.averagingDouble(BigDecimal::doubleValue))));
注意:这不是最佳解决方案。阅读下文(从编辑开始)以获得更好的方法。
您可以使用 Collector.of
将每个州的工资累加到 ArrayList
中,然后使用 finisher 函数计算每个州的平均值:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collector.<BigDecimal, List<BigDecimal>, BigDecimal>of(
ArrayList::new, // create accumulator
List::add, // add to accumulator
(l1, l2) -> { // combine two partial accumulators
l1.addAll(l2);
return l1;
},
l -> l.stream() // finish with a reduction that returns average
.reduce(BigDecimal.ZERO, BigDecimal::add)
.divide(BigDecimal.valueOf(l.size()))))));
这种方法不会丢失精度,因为每个计算每个状态平均值的操作都是在 BigDecimal
个实例上执行的。
编辑: 正如@Holger 在他的评论中指出的那样,这不是解决问题的最佳方法,因为使用 ArrayList
来存储所有 BigDecimal
实例根本不需要计算平均值。相反,使用 BigDecimal
来累加部分和并使用 long
来累加计数就足够了(这是@JB Nizet 在他的回答中遵循的方法,这里我修改了一些小细节).
这是考虑到这些因素的修改版本:
private static class Acc {
BigDecimal sum = BigDecimal.ZERO;
long count = 0;
void add(BigDecimal v) {
sum = sum.add(v);
count++;
}
Acc merge(Acc acc) {
sum = sum.add(acc.sum);
count += acc.count;
return this;
}
BigDecimal avg() {
return sum.divide(BigDecimal.valueOf(count));
}
}
Acc
是用来累加部分结果的class(包括部分和和计数)。
现在,我们可以使用 Collector.of
和 class:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary,
Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg))));
或者更好的是,我们可以声明一个辅助方法,其中 Acc
class 是本地 class:
public static Collector<BigDecimal, ?, BigDecimal> averagingBigDecimal() {
class Acc { // local class, lives only inside this method :P
BigDecimal sum = BigDecimal.ZERO;
long count = 0;
void add(BigDecimal value) {
sum = sum.add(value);
count++;
}
Acc merge(Acc acc) {
sum = sum.add(acc.sum);
count += acc.count;
return this;
}
BigDecimal avg() {
return sum.divide(BigDecimal.valueOf(count));
}
}
return Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg);
}
此方法现在可以如下使用:
Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
Collectors.groupingBy(Person::getState,
Collectors.mapping(Person::getSalary, averagingBigDecimal())));
如果您需要 BigDecimal
精度,并且不介意额外的迭代,您可以这样做:
static Map<String, BigDecimal> averageByState(List<Person> persons) {
// collect sums
Map<String, BigDecimal> sumByState = persons.stream()
.collect(groupingBy(
Person::getState,
HashMap::new,
mapping(Person::getSalary, reducing(BigDecimal.ZERO, BigDecimal::add))));
// collect counts
Map<String, Long> countByState = persons.stream()
.collect(groupingBy(Person::getState, counting()));
// merge
sumByState.replaceAll((state, sum) -> sum.divide(BigDecimal.valueOf(countByState.get(state))));
return sumByState;
}