From 30956118f25ae50e8427c75bb1776d9580a98cf7 Mon Sep 17 00:00:00 2001
From: Christian Cunningham <cc@localhost>
Date: Mon, 14 Feb 2022 11:21:37 -0700
Subject: Priority Inversion Protection

---
 include/cpu.h      |  1 +
 src/sys/schedule.c | 30 +++++++++++++++++++++++++++++-
 src/tests/test.c   | 26 +++++++++++++++++++++-----
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/include/cpu.h b/include/cpu.h
index 6dbaa74..9bda3e8 100644
--- a/include/cpu.h
+++ b/include/cpu.h
@@ -90,6 +90,7 @@ static inline void* getirqstack(void)
 #define SYS_YIELD       0
 #define SYS_TIME        1
 #define SYS_SCHED       2
+#define SYS_YIELD_HIGH  2
 #define SYS_FREE_STACK  3
 #define SYS_LOCK        4
 #define SYS_UNLOCK      5
diff --git a/src/sys/schedule.c b/src/sys/schedule.c
index 1643670..913d66e 100644
--- a/src/sys/schedule.c
+++ b/src/sys/schedule.c
@@ -105,6 +105,10 @@ void add_thread(void* pc, void* arg, unsigned char priority)
 	}
 	trb->queue[trb->woffset++] = thread;
 	trb->woffset %= TQUEUE_MAX;
+	unsigned long mode = getmode() & 0x1F;
+	if (mode == 0x10) {
+		sys0(SYS_YIELD_HIGH);
+	}
 }
 
 void uart_scheduler(void)
@@ -187,6 +191,30 @@ void sched_mutex_yield(void* m)
 	trb->roffset %= TQUEUE_MAX;
 	trbm->queue[trbm->woffset++] = rthread;
 	trbm->woffset %= TQUEUE_MAX;
+	for (int p = 0; p < PRIORITIES; p++) {
+		struct ThreadRotBuffer* trbm = &scheduler.thread_queues[p].mwait;
+		unsigned long roffset = trbm->roffset;
+		while (roffset != trbm->woffset) {
+			if (trbm->queue[roffset]->mptr == m && trbm->queue[roffset] != rthread) {
+				trb->queue[trb->woffset++] = trbm->queue[roffset];
+				trb->woffset %= TQUEUE_MAX;
+				// Copy all next backward to fill space
+				unsigned long coffset = roffset;
+				while (coffset != trbm->woffset) {
+					trbm->queue[coffset] = trbm->queue[(coffset+1)%TQUEUE_MAX];
+					coffset++;
+					coffset %= TQUEUE_MAX;
+				}
+				if(trbm->woffset == 0)
+					trbm->woffset = TQUEUE_MAX-1;
+				else
+					trbm->woffset--;
+				return;
+			}
+			roffset++;
+			roffset %= TQUEUE_MAX;
+		}
+	}
 }
 
 void sched_mutex_resurrect(void* m)
@@ -197,7 +225,7 @@ void sched_mutex_resurrect(void* m)
 		while (roffset != trbm->woffset) {
 			if (trbm->queue[roffset]->mptr == m) {
 				trbm->queue[roffset]->mptr = 0;
-				struct ThreadRotBuffer* trb = &scheduler.thread_queues[p].ready;
+				struct ThreadRotBuffer* trb = &scheduler.thread_queues[trbm->queue[roffset]->priority].ready;
 				trb->queue[trb->woffset++] = trbm->queue[roffset];
 				trb->woffset %= TQUEUE_MAX;
 				// Copy all next backward to fill space
diff --git a/src/tests/test.c b/src/tests/test.c
index 612c3e0..d828163 100644
--- a/src/tests/test.c
+++ b/src/tests/test.c
@@ -36,31 +36,47 @@ void test_entry(void)
 //static struct Mutex testm = {.addr = 0, .pid = 0};
 static struct Lock testm = {.pid = 0};
 
+void ctest1(void);
+void ctest2(void);
+void ctest3(void);
+void ctest4(void);
+
 void ctest1(void)
 {
 	uart_string("1 Started\n");
+	uart_string("1 Locking\n");
 	lock(&testm);
+	add_thread(ctest3, 0, 3);
+	add_thread(ctest2, 0, 2);
+	uart_string("1 Unlocking\n");
+	unlock(&testm);
 	uart_string("1 Finished\n");
 }
 
 void ctest2(void)
 {
 	uart_string("2 Started\n");
+	add_thread(ctest4, 0, 3);
+	uart_string("2 Locking\n");
 	lock(&testm);
-	uart_string("2 Finished\n");
+	uart_string("2 Unlocking\n");
 	unlock(&testm);
+	uart_string("2 Finished\n");
 }
 
 void ctest3(void)
 {
 	uart_string("3 Started\n");
-	unlock(&testm);
 	uart_string("3 Finished\n");
 }
 
+void ctest4(void)
+{
+	uart_string("4 Started\n");
+	uart_string("4 Finished\n");
+}
+
 void btest(void)
 {
-	add_thread(ctest1, 0, 1);
-	add_thread(ctest2, 0, 2);
-	add_thread(ctest3, 0, 3);
+	add_thread(ctest1, 0, 3);
 }
-- 
cgit v1.2.1