我如何使用GroupBy创建一个映射,其值是BigDecimal字段的平均值



我有以下类:

final public class Person {
 private final String name;
 private final String state;
 private final BigDecimal salary;
 public Person(String name, String state, BigDecimal salary) {
    this.name = name;
    this.state = state;
    this.salary = salary;
 }
 //getters omitted for brevity...
}

我想创建一个地图,该地图列出了按州按州的平均值。我该如何使用Java8流?我试图在Groupby上使用下游收藏家,但无法以优雅的方式这样做。

我做了以下内容,但看起来很丑陋:

Stream.of(p1,p2,p3,p4).collect(groupingBy(Person::getState, mapping(d -> d.getSalary(), toList())))
.forEach((state,wageList) -> {
        System.out.print(state+"-> ");
        final BigDecimal[] wagesArray = wageList.stream()
                .map(bd -> new BigDecimal[]{bd, BigDecimal.ONE})
                .reduce((a, b) -> new BigDecimal[]{a[0].add(b[0]), a[1].add(BigDecimal.ONE)})
                .get();
        System.out.println(wagesArray[0].divide(wagesArray[1])
                                        .setScale(2, RoundingMode.CEILING));
    });

有更好的方法吗?

这是一个仅使用bigdecimal算术的完整示例,并展示了如何实现自定义收集器

import java.math.BigDecimal;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public final class Person {
    private final String name;
    private final String state;
    private final BigDecimal salary;
    public Person(String name, String state, BigDecimal salary) {
        this.name = name;
        this.state = state;
        this.salary = salary;
    }
    public String getName() {
        return name;
    }
    public String getState() {
        return state;
    }
    public BigDecimal getSalary() {
        return salary;
    }
    public static void main(String[] args) {
        Person p1 = new Person("John", "NY", new BigDecimal("2000"));
        Person p2 = new Person("Jack", "NY", new BigDecimal("3000"));
        Person p3 = new Person("Jane", "GA", new BigDecimal("1500"));
        Person p4 = new Person("Jackie", "GA", new BigDecimal("2500"));
        Map<String, BigDecimal> result =
            Stream.of(p1, p2, p3, p4).collect(
                Collectors.groupingBy(Person::getState,
                                      Collectors.mapping(Person::getSalary,
                                                         new AveragingCollector())));
        System.out.println("result = " + result);
    }
    private static class AveragingCollector implements Collector<BigDecimal, IntermediateResult, BigDecimal> {
        @Override
        public Supplier<IntermediateResult> supplier() {
            return IntermediateResult::new;
        }
        @Override
        public BiConsumer<IntermediateResult, BigDecimal> accumulator() {
            return IntermediateResult::add;
        }
        @Override
        public BinaryOperator<IntermediateResult> combiner() {
            return IntermediateResult::combine;
        }
        @Override
        public Function<IntermediateResult, BigDecimal> finisher() {
            return IntermediateResult::finish
        }
        @Override
        public Set<Characteristics> characteristics() {
            return Collections.emptySet();
        }
    }
    private static class IntermediateResult {
        private int count = 0;
        private BigDecimal sum = BigDecimal.ZERO;
        IntermediateResult() {
        }
        void add(BigDecimal value) {
            this.sum = this.sum.add(value);
            this.count++;
        }
        IntermediateResult combine(IntermediateResult r) {
            this.sum = this.sum.add(r.sum);
            this.count += r.count;
            return this;
        }
        BigDecimal finish() {
            return sum.divide(BigDecimal.valueOf(count), 2, BigDecimal.ROUND_HALF_UP);
        }
    }
}

如果您接受将自己的大十分值转换为两倍(对于平均工资是完全可以接受的,恕我直言(,则可以只使用

Map<String, Double> result2 =
            Stream.of(p1, p2, p3, p4).collect(
                Collectors.groupingBy(Person::getState,
                                      Collectors.mapping(Person::getSalary,
                                                         Collectors.averagingDouble(BigDecimal::doubleValue))));

如果您需要BigDecimal精度,并且您不介意额外的迭代,则可以做到这一点:

static Map<String, BigDecimal> averageByState(List<Person> persons) {
    // collect sums
    Map<String, BigDecimal> sumByState = persons.stream()
            .collect(groupingBy(
                    Person::getState,
                    HashMap::new,
                    mapping(Person::getSalary, reducing(BigDecimal.ZERO, BigDecimal::add))));
    // collect counts
    Map<String, Long> countByState = persons.stream()
            .collect(groupingBy(Person::getState, counting()));
    // merge
    sumByState.replaceAll((state, sum) -> sum.divide(BigDecimal.valueOf(countByState.get(state))));
    return sumByState;
}

注意:这不是最好的解决方案。阅读下面的(从编辑开启(以获取更好的方法。

您可以使用Collector.of以每个国家的基础累积薪水到ArrayList,然后使用终结器函数来计算每个状态的平均值:

Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
    Collectors.groupingBy(Person::getState,
        Collectors.mapping(Person::getSalary,
            Collector.<BigDecimal, List<BigDecimal>, BigDecimal>of(
                ArrayList::new, // create accumulator
                List::add,      // add to accumulator
                (l1, l2) -> {   // combine two partial accumulators
                    l1.addAll(l2);
                    return l1;
                },
                l -> l.stream() // finish with a reduction that returns average
                    .reduce(BigDecimal.ZERO, BigDecimal::add)
                    .divide(BigDecimal.valueOf(l.size()))))));

这种方法不会失去精度,因为每个操作都计算在BigDecimal实例上执行每个状态的平均值。

编辑:正如@holger在他的评论中指示的那样,这不是解决问题的最佳方法,因为使用ArrayList来存储所有BigDecimal实例,完全不需要计算平均。取而代之的是,使用BigDecimal积累部分总和,而long积累了计数(这是@jb nizet的方法,在他的答案中,我在这里修改了一些次要细节(。

(。

这是一个修改版本,考虑了这些注意事项:

private static class Acc {
    BigDecimal sum = BigDecimal.ZERO;
    long count = 0;
    void add(BigDecimal v) {
        sum = sum.add(v);
        count++;
    }
    Acc merge(Acc acc) {
        sum = sum.add(acc.sum);
        count += acc.count;
        return this;
    }
    BigDecimal avg() {
        return sum.divide(BigDecimal.valueOf(count));
    }
}

Acc是用于累积部分结果的类(部分总和和计数(。

现在,我们可以在此类中使用Collector.of

Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
    Collectors.groupingBy(Person::getState,
        Collectors.mapping(Person::getSalary,
            Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg))));

甚至更好,我们可以声明一种辅助方法,而Acc类是本地类:

public static Collector<BigDecimal, ?, BigDecimal> averagingBigDecimal() {
    class Acc { // local class, lives only inside this method :P
        BigDecimal sum = BigDecimal.ZERO;
        long count = 0;
        void add(BigDecimal value) {
            sum = sum.add(value);
            count++;
        }
        Acc merge(Acc acc) {
            sum = sum.add(acc.sum);
            count += acc.count;
            return this;
        }
        BigDecimal avg() {
            return sum.divide(BigDecimal.valueOf(count));
        }
    }
    return Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg);
}

现在可以使用以下方法:

Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
    Collectors.groupingBy(Person::getState,
        Collectors.mapping(Person::getSalary, averagingBigDecimal())));

这是一种方法。尽管我仍然相信可能还有其他更好的方法。

Map<String, Double> collect = Arrays.asList(p, p1, p2, p3, p4)
            .stream()
            .collect(groupingBy(k -> k.getState(), mapping(v -> v.salary, toList())))
            .entrySet()
            .stream()
            .collect(toMap(k -> k.getKey(), v -> v.getValue().stream().mapToDouble(BigDecimal::doubleValue).average().getAsDouble()));

最新更新