在 R 中,如何按行使用 argmax 函数并计算变量数?

In R, how can I use argmax function by row and count the variable number?

我有原始数据集。以下是原始数据样本:

          sentiment
     pos    neu    neg    likes_count
1     1      0      0         5
2     0.2   0.3    0.5        6 
3     0.3   0.3    0.4        6
4     0      0      1         3
5     0.2   0.7    0.1        1

在这个情绪原始数据中,"pos"指的是评论中正面的概率,"neu"指的是中立的概率,"neg"指的是负面的概率。 "likes_count" 是Facebook 评论的点赞数。我想在 pos、neu 和 neg 中选择概率最高的。并知道哪种情绪获得最高 likes_count。 例如pos : 0.6, neu : 0.2, neg : 0.2为正面评价。

我想要的输出如下:

    Total_likes_count
Pos        5
Neu        1
Neg        15

你能帮我做这个吗??

dput原始数据格式:

structure(list(likes_count = c(0L, 0L, 1L, 0L, 0L, 0L, 1L, 1L, 
0L, 0L, 0L, 3L, 1L, 2L, 2L, 1L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 
2L, 5L, 2L, 0L, 1L, 0L), neg = c(0, 0.41, 0, 0, 0, 
0.19, 0, 1, 0, 0.52, 0, 0.11, 0.02, 0.05, 0.02, 0, 0, 0.01, 0.01, 
0, 0, 0.97, 0, 0.01, 0.24, 0.34, 0.94, 0.44, 0.15, 0.01), neu = c(0, 
0.1, 0, 0, 0, 0, 0.13, 0, 0.98, 0.32, 0, 0.08, 0.02, 0.04, 0.07, 
0, 0, 0.98, 0.07, 0, 0, 0.03, 0.02, 0.21, 0.48, 0.62, 0.01, 0.2, 
0.85, 0.67), pos = c(1, 0.48, 1, 1, 1, 0.81, 0.86, 
0, 0.02, 0.16, 1, 0.81, 0.96, 0.91, 0.91, 1, 1, 0.01, 0.92, 1, 
1, 0, 0.98, 0.78, 0.28, 0.04, 0.05, 0.36, 0, 0.32)), na.action = structure(c(`7` = 7L, 
`11` = 11L, `38` = 38L, `53` = 53L, `88` = 88L, `101` = 101L, 
`106` = 106L, `138` = 138L, `139` = 139L, `155` = 155L, `165` = 165L, 
`176` = 176L, `178` = 178L, `179` = 179L, `199` = 199L, `200` = 200L, 
`201` = 201L, `208` = 208L, `209` = 209L, `250` = 250L, `281` = 281L, 
`293` = 293L, `299` = 299L, `316` = 316L, `321` = 321L, `322` = 322L, 
`328` = 328L, `332` = 332L, `333` = 333L, `334` = 334L, `335` = 335L, 
`336` = 336L, `342` = 342L, `347` = 347L, `352` = 352L, `354` = 354L, 
`355` = 355L, `395` = 395L, `398` = 398L, `400` = 400L, `411` = 411L, 
`420` = 420L, `449` = 449L, `454` = 454L, `456` = 456L, `457` = 457L, 
`464` = 464L, `471` = 471L, `491` = 491L, `495` = 495L, `502` = 502L, 
`503` = 503L, `504` = 504L, `506` = 506L, `526` = 526L, `536` = 536L, 
`541` = 541L, `542` = 542L, `546` = 546L, `556` = 556L, `558` = 558L, 
`563` = 563L, `579` = 579L, `581` = 581L, `582` = 582L, `584` = 584L, 
`602` = 602L, `603` = 603L, `604` = 604L, `606` = 606L, `614` = 614L, 
`617` = 617L, `619` = 619L, `620` = 620L, `621` = 621L, `622` = 622L, 
`623` = 623L, `625` = 625L, `626` = 626L, `629` = 629L, `630` = 630L, 
`631` = 631L, `632` = 632L, `633` = 633L, `636` = 636L, `637` = 637L, 
`638` = 638L, `639` = 639L, `640` = 640L, `643` = 643L, `645` = 645L, 
`646` = 646L, `647` = 647L, `648` = 648L, `650` = 650L, `652` = 652L, 
`653` = 653L, `655` = 655L, `656` = 656L, `658` = 658L, `661` = 661L, 
`665` = 665L, `666` = 666L, `667` = 667L, `669` = 669L, `671` = 671L, 
`673` = 673L, `674` = 674L, `679` = 679L, `680` = 680L, `682` = 682L, 
`683` = 683L, `684` = 684L, `685` = 685L, `686` = 686L, `687` = 687L, 
`689` = 689L, `692` = 692L, `694` = 694L, `696` = 696L, `697` = 697L, 
`699` = 699L, `700` = 700L, `701` = 701L, `702` = 702L, `703` = 703L, 
`704` = 704L, `705` = 705L, `707` = 707L, `708` = 708L, `712` = 712L, 
`713` = 713L, `714` = 714L, `717` = 717L, `718` = 718L, `719` = 719L, 
`720` = 720L, `721` = 721L, `722` = 722L, `723` = 723L, `724` = 724L, 
`725` = 725L, `726` = 726L, `727` = 727L, `728` = 728L, `730` = 730L, 
`738` = 738L, `750` = 750L, `753` = 753L, `754` = 754L, `761` = 761L, 
`766` = 766L, `767` = 767L, `769` = 769L, `771` = 771L, `775` = 775L, 
`786` = 786L, `808` = 808L, `810` = 810L, `812` = 812L, `814` = 814L, 
`817` = 817L, `820` = 820L, `841` = 841L, `862` = 862L, `864` = 864L, 
`865` = 865L, `866` = 866L, `867` = 867L, `874` = 874L, `877` = 877L, 
`878` = 878L, `881` = 881L, `882` = 882L, `890` = 890L, `891` = 891L, 
`913` = 913L, `934` = 934L, `938` = 938L, `951` = 951L, `961` = 961L, 
`962` = 962L, `967` = 967L, `971` = 971L, `972` = 972L, `981` = 981L, 
`983` = 983L, `986` = 986L, `988` = 988L, `1000` = 1000L, `1014` = 1014L
), class = "omit"), row.names = c(NA, -30L), class = "data.frame")

您可以使用max.col 来查看哪一列具有最大值。

data.table

library(data.table)

data.table(likes = df$likes_count, maxcol = names(df)[-1][max.col(df[-1])]
)[, .(likes = sum(likes)), maxcol]

#    maxcol likes
# 1:    pos    12
# 2:    neg     3
# 3:    neu     8

dplyr

library(dplyr)

tibble(likes = df$likes_count, maxcol = names(df)[-1][max.col(df[-1])]) %>% 
  group_by(maxcol) %>% 
  summarise(likes = sum(likes))

# # A tibble: 3 x 2
#   maxcol likes
#   <chr>  <int>
# 1 neg        3
# 2 neu        8
# 3 pos       12

我们可以使用

table(names(df1)[-1][max.col(df1[-1], 'first')])

如果要得到'likes_count'的sum一个base R选项是

rowsum(df1$likes_count, group = names(df1)[-1][max.col(df1[-1], 'first')])
#     [,1]
#neg    3
#neu    8
#pos   12

或使用 pivot_longer 转换为长格式并获得加权计数

library(dplyr)
library(tidyr)
df1 %>% 
  mutate(rn = row_number()) %>% 
  pivot_longer(cols = neg:pos) %>% group_by(rn) %>% 
  slice(which.max(value)) %>% 
  ungroup %>% 
  count(name, wt = likes_count)
# A tibble: 3 x 2
#  name      n
#  <chr> <int>
#1 neg       3
#2 neu       8
#3 pos      12