// SPDX-License-Identifier: GPL-2.0-only
/*
 * ntsync.c - Kernel driver for NT synchronization primitives
 *
 * Copyright (C) 2024 Elizabeth Figura <zfigura@codeweavers.com>
 */

#include <linux/anon_inodes.h>
#include <linux/atomic.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hrtimer.h>
#include <linux/ktime.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/overflow.h>
#include <linux/sched.h>
#include <linux/sched/signal.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <uapi/linux/ntsync.h>

#define NTSYNC_NAME	"ntsync"

enum ntsync_type {
	NTSYNC_TYPE_SEM,
	NTSYNC_TYPE_MUTEX,
	NTSYNC_TYPE_EVENT,
};

/*
 * Individual synchronization primitives are represented by
 * struct ntsync_obj, and each primitive is backed by a file.
 *
 * The whole namespace is represented by a struct ntsync_device also
 * backed by a file.
 *
 * Both rely on struct file for reference counting. Individual
 * ntsync_obj objects take a reference to the device when created.
 * Wait operations take a reference to each object being waited on for
 * the duration of the wait.
 */
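
/*
 * A minimal sketch (illustration only, not the driver's actual helpers) of
 * the reference pattern described above: creating an object pins the device,
 * and a wait pins each object passed to it via its file:
 *
 *	obj->dev = dev;
 *	get_file(dev->file);		// object -> device reference
 *
 *	// in a wait ioctl, for each object fd supplied by the caller:
 *	struct file *file = fget(fd);	// wait -> object reference
 *	...
 *	fput(file);			// dropped when the wait finishes
 */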

struct ntsync_obj {
	spinlock_t lock;
	int dev_locked;

	enum ntsync_type type;

	struct file *file;
	struct ntsync_device *dev;

	/* The following fields are protected by the object lock. */
	union {
		struct {
			__u32 count;
			__u32 max;
		} sem;
		struct {
			__u32 count;
			pid_t owner;
			bool ownerdead;
		} mutex;
		struct {
			bool manual;
			bool signaled;
		} event;
	} u;

	/*
	 * any_waiters is protected by the object lock, but all_waiters is
	 * protected by the device wait_all_lock.
	 */
	struct list_head any_waiters;
	struct list_head all_waiters;

	/*
	 * Hint describing how many tasks are queued on this object in a
	 * wait-all operation.
	 *
	 * Any time we do a wake, we may need to wake "all" waiters as well as
	 * "any" waiters. In order to atomically wake "all" waiters, we must
	 * lock all of the objects, and that means grabbing the wait_all_lock
	 * below (and, due to lock ordering rules, before locking this object).
	 * However, wait-all is a rare operation, and grabbing the wait-all
	 * lock for every wake would create unnecessary contention.
	 * Therefore we first check whether all_hint is zero, and, if it is,
	 * we skip trying to wake "all" waiters.
	 *
	 * Since wait requests must originate from user-space threads, we're
	 * limited here by PID_MAX_LIMIT, so there's no risk of overflow.
	 */
	atomic_t all_hint;
};
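
/*
 * A minimal sketch (illustration only; the wake helpers named here are
 * hypothetical) of the fast path that all_hint enables: a waker only pays
 * for dev->wait_all_lock when at least one wait-all entry is queued on the
 * object.
 *
 *	if (atomic_read(&obj->all_hint))
 *		try_wake_all_waiters(obj->dev, obj);	// needs wait_all_lock
 *	try_wake_any_waiters(obj);			// only needs obj->lock
 */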

struct ntsync_q_entry {
	struct list_head node;
	struct ntsync_q *q;
	struct ntsync_obj *obj;
	__u32 index;
};

struct ntsync_q {
	struct task_struct *task;
	__u32 owner;

	/*
	 * Protected via atomic_try_cmpxchg(). Only the thread that wins the
	 * compare-and-swap may actually change object states and wake this
	 * task.
	 */
	atomic_t signaled;

	bool all;
	bool ownerdead;
	__u32 count;
	struct ntsync_q_entry entries[];
};
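
/*
 * A minimal sketch (not part of the driver; the helper name and the -1
 * "not yet signaled" sentinel are assumptions) of the protocol described for
 * q->signaled above: whichever waker wins the compare-and-swap is the only
 * thread allowed to consume object state and wake the waiting task.
 */
static inline bool ntsync_q_try_signal_sketch(struct ntsync_q *q, __u32 index)
{
	int not_signaled = -1;

	/* Only one waker can move signaled away from the sentinel. */
	if (!atomic_try_cmpxchg(&q->signaled, &not_signaled, index))
		return false;

	wake_up_process(q->task);
	return true;
}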

struct ntsync_device {
	/*
	 * Wait-all operations must atomically grab all objects, and be totally
	 * ordered with respect to each other and wait-any operations.
	 * If one thread is trying to acquire several objects, another thread
	 * cannot touch any of those objects at the same time.
	 *
	 * This device-wide lock is used to serialize wait-for-all
	 * operations, and operations on an object that is involved in a
	 * wait-for-all.
	 */
	struct mutex wait_all_lock;

	struct file *file;
};

/*
 * Single objects are locked using obj->lock.
 *
 * Multiple objects are 'locked' while holding dev->wait_all_lock.
 * In this case, however, individual objects are not locked by holding
 * obj->lock, but by setting obj->dev_locked.
 *
 * This means that locking a single object is slightly more complicated than
 * usual. Specifically, obj->dev_locked must be checked after acquiring
 * obj->lock; if it is set, the lock must be dropped and dev->wait_all_lock
 * acquired instead, in order to serialize against the multi-object operation.
 */
static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	lockdep_assert_held(&dev->wait_all_lock);
	lockdep_assert(obj->dev == dev);
	spin_lock(&obj->lock);
	/*
	 * Setting obj->dev_locked inside obj->lock ensures that anyone who
	 * subsequently takes obj->lock will observe the value.
	 */
	obj->dev_locked = 1;
	spin_unlock(&obj->lock);
}

static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	lockdep_assert_held(&dev->wait_all_lock);
	lockdep_assert(obj->dev == dev);
	spin_lock(&obj->lock);
	obj->dev_locked = 0;
	spin_unlock(&obj->lock);
}
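
/*
 * A minimal sketch (illustration only, using a hypothetical helper name) of
 * how a multi-object operation is expected to use dev->wait_all_lock and the
 * helpers above: grab the device-wide lock, mark every object dev-locked,
 * operate on the whole set atomically, then release in the same section.
 */
static inline void ntsync_lock_all_sketch(struct ntsync_device *dev,
					  struct ntsync_obj **objs, __u32 count)
{
	__u32 i;

	mutex_lock(&dev->wait_all_lock);
	for (i = 0; i < count; i++)
		dev_lock_obj(dev, objs[i]);
	/* ... examine or consume the state of every object atomically ... */
	for (i = 0; i < count; i++)
		dev_unlock_obj(dev, objs[i]);
	mutex_unlock(&dev->wait_all_lock);
}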

static void obj_lock(struct ntsync_obj *obj)
{
	struct ntsync_device *dev = obj->dev;

	for (;;) {
		spin_lock(&obj->lock);
		if (likely(!obj->dev_locked))
			break;
		spin_unlock(&obj->lock);
		mutex_lock(&dev->wait_all_lock);
		spin_lock(&obj