From: Harry Ramsey <harry.ramsey@arm.com>
The tcp_zerocopy_receive struct is used to create a shared memory region between userspace and the kernel into which received network packets are copied directly from the network interface card.
In the PCuABI, a user pointer is a 129-bit capability, so the __u64 type is not big enough to hold it. Use the __kernel_uintptr_t type instead, which is big enough on the affected architectures while remaining 64-bit on others. Additionally, use the special copy routines when copying user pointers to and from userspace.
An additional check has been added to ensure that the capability has ownership of the mapping (backed by the TCP socket fd) that is used to receive network packets.
Signed-off-by: Harry Ramsey <harry.ramsey@arm.com>
Co-developed-by: Tudor Cretu <tudor.cretu@arm.com>
Signed-off-by: Tudor Cretu <tudor.cretu@arm.com>
---
 include/uapi/linux/tcp.h | 10 +++++-----
 net/ipv4/tcp.c           | 40 +++++++++++++++++++++++-----------------
 2 files changed, 28 insertions(+), 22 deletions(-)
diff --git a/include/uapi/linux/tcp.h b/include/uapi/linux/tcp.h index 879eeb0a084b..4da689910e0e 100644 --- a/include/uapi/linux/tcp.h +++ b/include/uapi/linux/tcp.h @@ -352,15 +352,15 @@ struct tcp_diag_md5sig {
#define TCP_RECEIVE_ZEROCOPY_FLAG_TLB_CLEAN_HINT 0x1 struct tcp_zerocopy_receive { - __u64 address; /* in: address of mapping */ - __u32 length; /* in/out: number of bytes to map/mapped */ - __u32 recv_skip_hint; /* out: amount of bytes to skip */ + __kernel_uintptr_t address; /* in: address of mapping */ + __u32 length; /* in/out: number of bytes to map/mapped */ + __u32 recv_skip_hint; /* out: amount of bytes to skip */ __u32 inq; /* out: amount of bytes in read queue */ __s32 err; /* out: socket error */ - __u64 copybuf_address; /* in: copybuf address (small reads) */ + __kernel_uintptr_t copybuf_address; /* in: copybuf address (small reads) */ __s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */ __u32 flags; /* in: flags */ - __u64 msg_control; /* ancillary data */ + __kernel_uintptr_t msg_control; /* ancillary data */ __u64 msg_controllen; __u32 msg_flags; __u32 reserved; /* set to 0 for now */ diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 1459ebf810d8..f3ee592a73e0 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -1968,15 +1968,15 @@ static int get_compat64_tcp_zerocopy_receive(struct tcp_zerocopy_receive *zc, if (copy_from_sockptr(&compat_zc, src, size)) return -EFAULT;
- zc->address = compat_zc.address; + zc->address = (__kernel_uintptr_t)compat_ptr(compat_zc.address); zc->length = compat_zc.length; zc->recv_skip_hint = compat_zc.recv_skip_hint; zc->inq = compat_zc.inq; zc->err = compat_zc.err; - zc->copybuf_address = compat_zc.copybuf_address; + zc->copybuf_address = (__kernel_uintptr_t)compat_ptr(compat_zc.copybuf_address); zc->copybuf_len = compat_zc.copybuf_len; zc->flags = compat_zc.flags; - zc->msg_control = compat_zc.msg_control; + zc->msg_control = (__kernel_uintptr_t)compat_ptr(compat_zc.msg_control); zc->msg_controllen = compat_zc.msg_controllen; zc->msg_flags = compat_zc.msg_flags; zc->reserved = compat_zc.reserved; @@ -1988,7 +1988,7 @@ static int copy_tcp_zerocopy_receive_from_sockptr(struct tcp_zerocopy_receive *z { if (in_compat64()) return get_compat64_tcp_zerocopy_receive(zc, src, size); - if (copy_from_sockptr(zc, src, size)) + if (copy_from_sockptr_with_ptr(zc, src, size)) return -EFAULT; return 0; } @@ -1998,15 +1998,15 @@ static int set_compat64_tcp_zerocopy_receive(sockptr_t dst, size_t size) { struct compat_tcp_zerocopy_receive compat_zc = { - .address = zc->address, + .address = (__u64)zc->address, .length = zc->length, .recv_skip_hint = zc->recv_skip_hint, .inq = zc->inq, .err = zc->err, - .copybuf_address = zc->copybuf_address, + .copybuf_address = (__u64)zc->copybuf_address, .copybuf_len = zc->copybuf_len, .flags = zc->flags, - .msg_control = zc->msg_control, + .msg_control = (__u64)zc->msg_control, .msg_controllen = zc->msg_controllen, .msg_flags = zc->msg_flags, .reserved = zc->reserved, @@ -2021,7 +2021,7 @@ static int copy_tcp_zerocopy_receive_to_sockptr(sockptr_t dst, { if (in_compat64()) return set_compat64_tcp_zerocopy_receive(dst, zc, size); - if (copy_to_sockptr(dst, zc, size)) + if (copy_to_sockptr_with_ptr(dst, zc, size)) return -EFAULT; return 0; } @@ -2070,7 +2070,7 @@ static int receive_fallback_to_copy(struct sock *sk, struct tcp_zerocopy_receive *zc, int inq, struct 
scm_timestamping_internal *tss) { - unsigned long copy_address = (unsigned long)zc->copybuf_address; + user_uintptr_t copy_address = (user_uintptr_t)zc->copybuf_address; struct msghdr msg = {}; struct iovec iov; int err; @@ -2081,7 +2081,7 @@ static int receive_fallback_to_copy(struct sock *sk, if (copy_address != zc->copybuf_address) return -EINVAL;
- err = import_single_range(ITER_DEST, uaddr_to_user_ptr(copy_address), + err = import_single_range(ITER_DEST, (void __user *)copy_address, inq, &iov, &msg.msg_iter); if (err) return err; @@ -2107,7 +2107,7 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, struct sk_buff *skb, u32 copylen, u32 *offset, u32 *seq) { - unsigned long copy_address = (unsigned long)zc->copybuf_address; + user_uintptr_t copy_address = (user_uintptr_t)zc->copybuf_address; struct msghdr msg = {}; struct iovec iov; int err; @@ -2115,7 +2115,7 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, if (copy_address != zc->copybuf_address) return -EINVAL;
- err = import_single_range(ITER_DEST, uaddr_to_user_ptr(copy_address), + err = import_single_range(ITER_DEST, (void __user *)copy_address, copylen, &iov, &msg.msg_iter); if (err) return err; @@ -2240,11 +2240,11 @@ static void tcp_zc_finalize_rx_tstamp(struct sock *sk, struct tcp_zerocopy_receive *zc, struct scm_timestamping_internal *tss) { - unsigned long msg_control_addr; + user_uintptr_t msg_control_addr; struct msghdr cmsg_dummy;
- msg_control_addr = (unsigned long)zc->msg_control; - cmsg_dummy.msg_control_user = uaddr_to_user_ptr(msg_control_addr); + msg_control_addr = zc->msg_control; + cmsg_dummy.msg_control_user = (void __user *)msg_control_addr; cmsg_dummy.msg_controllen = (__kernel_size_t)zc->msg_controllen; cmsg_dummy.msg_flags = in_compat_syscall() @@ -2254,8 +2254,8 @@ static void tcp_zc_finalize_rx_tstamp(struct sock *sk, if (zc->msg_control == msg_control_addr && zc->msg_controllen == cmsg_dummy.msg_controllen) { tcp_recv_timestamp(&cmsg_dummy, sk, tss); - zc->msg_control = (__u64) - user_ptr_addr(cmsg_dummy.msg_control_user); + zc->msg_control = + (__kernel_uintptr_t)cmsg_dummy.msg_control_user; zc->msg_controllen = (__u64)cmsg_dummy.msg_controllen; zc->msg_flags = (__u32)cmsg_dummy.msg_flags; @@ -2303,6 +2303,12 @@ static int tcp_zerocopy_receive(struct sock *sk, return 0; }
+#ifdef CONFIG_CHERI_PURECAP_UABI + /* Check that the pointer has ownership of the mapping */ + if (!cheri_check_cap((void __user *)zc->address, zc->length, CHERI_PERM_SW_VMEM)) + return -EINVAL; +#endif + mmap_read_lock(current->mm);
vma = vma_lookup(current->mm, address);