关于 data.table:R – 在巨大的 data.frame 中改变条件

R - mutate condition in huge data.frame

所以我有非常大的数据集(>1000 obs. of >15000 variables),我不想用 1 替换所有值 >1 并保持其余部分不变。

示例数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
data <- data.frame(a = 1:10, b = -1:-10, c = letters[1:10])

    a   b c
1   1  -1 a
2   2  -2 b
3   3  -3 c
4   4  -4 d
5   5  -5 e
6   6  -6 f
7   7  -7 g
8   8  -8 h
9   9  -9 i
10 10 -10 j

这是我的dplyr方法:

1
2
3
4
5
6
7
data %>% mutate_if(is.numeric,
                                   funs(
                                     case_when(
                                       . >= 1 ~ 1,
                                       TRUE ~ as.double(.))
                                     )
                                   )

这需要很长时间才能处理原始数据。知道如何加快速度吗? data.table?


这个带有 data.table 的解决方案似乎有效,公平地说,它给出了一个 warning():

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
library(data.table)
library(purrr)
num_cols <- colnames(data)[map_lgl(data, is.numeric)] # select only the numerics

data[, (num_cols):= lapply(.SD, function(x) {
                                    x[x>1] = 1
                                    x}),
     .SDcols=num_cols
     ]
data
# a aa   b c
# 1: 1  1  -1 a
# 2: 1  1  -2 b
# 3: 1  1  -3 c
# 4: 1  1  -4 d
# 5: 1  1  -5 e
# 6: 1  1  -6 f
# 7: 1  1  -7 g
# 8: 1  1  -8 h
# 9: 1  1  -9 i
# 10: 1  1 -10 j

Warning message: In [.data.table(data, , :=((num_cols),
lapply(.SD, function(x) { : Supplied 2 columns to be assigned a list
(length 3) of values (1 unused)

使用的数据:

1
data <- data.table(a = 1:10, aa = 1:10, b = -1:-10, c = letters[1:10])

基准:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
microbenchmark::microbenchmark(
  dplyr = data %>% mutate_if(is.numeric,
                              funs(
                                case_when(
                                  . >= 1 ~ 1,
                                  TRUE ~ as.double(.))
                              )
  ),
  datatable = data[, (num_cols):= lapply(.SD, function(x) {
    x[x>1] = 1
    x})
    ],
  times = 100
)

# Unit: microseconds
# expr      min        lq      mean    median        uq       max neval
# dplyr 1465.088 1644.7690 2012.3148 1775.4730 1989.1065 19992.621   100
# datatable  372.282  399.0235  480.9405  440.0375  547.3055   831.398   100

公平地说,更新 Ronak Shah 解决方案更快:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
microbenchmark::microbenchmark(
  dplyr = data %>% mutate_if(is.numeric,
                              funs(
                                case_when(
                                  . >= 1 ~ 1,
                                  TRUE ~ as.double(.))
                              )
  ),
  datatable = data[, (num_cols):= lapply(.SD, function(x) {
    x[x>1] = 1
    x})
    ],
  base = {dataframe <- as.data.frame(data)
          dataframe[dataframe > 1] <- 1},
  times = 100
)
# Unit: microseconds
# expr      min        lq      mean   median        uq       max neval
# dplyr 1782.384 1902.1210 2549.3977 1995.116 2099.9800 55628.570   100
# datatable  394.817  422.7605  466.5329  441.690  512.9020   628.282   100
# base  118.987  135.5120  160.1595  154.291  176.2255   300.469   100


你可以试试:

1
2
apply(data[, which(sapply(data, is.numeric))], 2,
      function(x) {ifelse(x > 1, 1, x)})

它省略了 c 列,但之后您可以轻松地合并它。