update for HEAD-2003091401
[reactos.git] / ntoskrnl / ps / debug.c
index 963b195..a5b1662 100644 (file)
@@ -44,6 +44,7 @@
 #include <string.h>
 #include <internal/ps.h>
 #include <internal/ob.h>
+#include <internal/safe.h>
 
 #define NDEBUG
 #include <internal/debug.h>
@@ -54,7 +55,7 @@ VOID
 KeContextToTrapFrame(PCONTEXT Context,
                     PKTRAP_FRAME TrapFrame)
 {
-   if (Context->ContextFlags & CONTEXT_CONTROL)
+   if ((Context->ContextFlags & CONTEXT_CONTROL) == CONTEXT_CONTROL)
      {
        TrapFrame->Esp = Context->Esp;
        TrapFrame->Ss = Context->SegSs;
@@ -63,7 +64,7 @@ KeContextToTrapFrame(PCONTEXT Context,
        TrapFrame->Eflags = Context->EFlags;    
        TrapFrame->Ebp = Context->Ebp;
      }
-   if (Context->ContextFlags & CONTEXT_INTEGER)
+   if ((Context->ContextFlags & CONTEXT_INTEGER) == CONTEXT_INTEGER)
      {
        TrapFrame->Eax = Context->Eax;
        TrapFrame->Ebx = Context->Ebx;
@@ -76,20 +77,20 @@ KeContextToTrapFrame(PCONTEXT Context,
        TrapFrame->Esi = Context->Esi;
        TrapFrame->Edi = Context->Edi;
      }
-   if (Context->ContextFlags & CONTEXT_SEGMENTS)
+   if ((Context->ContextFlags & CONTEXT_SEGMENTS) == CONTEXT_SEGMENTS)
      {
        TrapFrame->Ds = Context->SegDs;
        TrapFrame->Es = Context->SegEs;
        TrapFrame->Fs = Context->SegFs;
        TrapFrame->Gs = Context->SegGs;
      }
-   if (Context->ContextFlags & CONTEXT_FLOATING_POINT)
+   if ((Context->ContextFlags & CONTEXT_FLOATING_POINT) == CONTEXT_FLOATING_POINT)
      {
        /*
         * Not handled
         */
      }
-   if (Context->ContextFlags & CONTEXT_DEBUG_REGISTERS)
+   if ((Context->ContextFlags & CONTEXT_DEBUG_REGISTERS) == CONTEXT_DEBUG_REGISTERS)
      {
        /*
         * Not handled
@@ -101,7 +102,7 @@ VOID
 KeTrapFrameToContext(PKTRAP_FRAME TrapFrame,
                     PCONTEXT Context)
 {
-   if (Context->ContextFlags & CONTEXT_CONTROL)
+   if ((Context->ContextFlags & CONTEXT_CONTROL) == CONTEXT_CONTROL)
      {
        Context->SegSs = TrapFrame->Ss;
        Context->Esp = TrapFrame->Esp;
@@ -110,7 +111,7 @@ KeTrapFrameToContext(PKTRAP_FRAME TrapFrame,
        Context->EFlags = TrapFrame->Eflags;
        Context->Ebp = TrapFrame->Ebp;
      }
-   if (Context->ContextFlags & CONTEXT_INTEGER)
+   if ((Context->ContextFlags & CONTEXT_INTEGER) == CONTEXT_INTEGER)
      {
        Context->Eax = TrapFrame->Eax;
        Context->Ebx = TrapFrame->Ebx;
@@ -127,27 +128,27 @@ KeTrapFrameToContext(PKTRAP_FRAME TrapFrame,
        Context->Esi = TrapFrame->Esi;
        Context->Edi = TrapFrame->Edi;
      }
-   if (Context->ContextFlags & CONTEXT_SEGMENTS)
+   if ((Context->ContextFlags & CONTEXT_SEGMENTS) == CONTEXT_SEGMENTS)
      {
        Context->SegDs = TrapFrame->Ds;
        Context->SegEs = TrapFrame->Es;
        Context->SegFs = TrapFrame->Fs;
        Context->SegGs = TrapFrame->Gs;
      }
-   if (Context->ContextFlags & CONTEXT_DEBUG_REGISTERS)
+   if ((Context->ContextFlags & CONTEXT_DEBUG_REGISTERS) == CONTEXT_DEBUG_REGISTERS)
      {
        /*
         * FIXME: Implement this case
         */     
      }
-   if (Context->ContextFlags & CONTEXT_FLOATING_POINT)
+   if ((Context->ContextFlags & CONTEXT_FLOATING_POINT) == CONTEXT_FLOATING_POINT)
      {
        /*
         * FIXME: Implement this case
         */
      }
 #if 0
-   if (Context->ContextFlags & CONTEXT_EXTENDED_REGISTERS)
+   if ((Context->ContextFlags & CONTEXT_EXTENDED_REGISTERS) == CONTEXT_EXTENDED_REGISTERS)
      {
        /*
         * FIXME: Investigate this
@@ -157,15 +158,15 @@ KeTrapFrameToContext(PKTRAP_FRAME TrapFrame,
 }
 
 VOID STDCALL
-KeGetContextRundownRoutine(PKAPC Apc)
+KeGetSetContextRundownRoutine(PKAPC Apc)
 {
-   PKEVENT Event;
-   PNTSTATUS Status;
-   
-   Event = (PKEVENT)Apc->SystemArgument1;
-   Status = (PNTSTATUS)Apc->SystemArgument2;
-   (*Status) = STATUS_THREAD_IS_TERMINATING;
-   KeSetEvent(Event, IO_NO_INCREMENT, FALSE);
+  PKEVENT Event;
+  PNTSTATUS Status;
+
+  Event = (PKEVENT)Apc->SystemArgument1;   
+  Status = (PNTSTATUS)Apc->SystemArgument2;
+  (*Status) = STATUS_THREAD_IS_TERMINATING;
+  KeSetEvent(Event, IO_NO_INCREMENT, FALSE);
 }
 
 VOID STDCALL
@@ -179,97 +180,200 @@ KeGetContextKernelRoutine(PKAPC Apc,
  * copy the context of a thread into a buffer.
  */
 {
-   PKEVENT Event;
-   PCONTEXT Context;
-   PNTSTATUS Status;
+  PKEVENT Event;
+  PCONTEXT Context;
+  PNTSTATUS Status;
    
-   Context = (PCONTEXT)(*NormalContext);
-   Event = (PKEVENT)(*SystemArgument1);
-   Status = (PNTSTATUS)(*SystemArgument2);
+  Context = (PCONTEXT)(*NormalContext);
+  Event = (PKEVENT)(*SystemArgument1);
+  Status = (PNTSTATUS)(*SystemArgument2);
    
-   KeTrapFrameToContext(KeGetCurrentThread()->TrapFrame, Context);
+  KeTrapFrameToContext(KeGetCurrentThread()->TrapFrame, Context);
    
-   *Status = STATUS_SUCCESS;
-   KeSetEvent(Event, IO_NO_INCREMENT, FALSE);
+  *Status = STATUS_SUCCESS;
+  KeSetEvent(Event, IO_NO_INCREMENT, FALSE);
 }
 
 NTSTATUS STDCALL
 NtGetContextThread(IN HANDLE ThreadHandle,
-                  OUT PCONTEXT Context)
+                  OUT PCONTEXT UnsafeContext)
 {
-   PETHREAD Thread;
-   NTSTATUS Status;
-   
-   Status = ObReferenceObjectByHandle(ThreadHandle,
-                                     THREAD_GET_CONTEXT,
-                                     PsThreadType,
-                                     UserMode,
-                                     (PVOID*)&Thread,
-                                     NULL);
-   if (!NT_SUCCESS(Status))
-     {
-       return(Status);
-     }
-   if (Thread == PsGetCurrentThread())
-     {
-       /*
-        * I don't know if trying to get your own context makes much
-        * sense but we can handle it more efficently.
-        */
-       
-       KeTrapFrameToContext(Thread->Tcb.TrapFrame, Context);
-       ObDereferenceObject(Thread);
-       return(STATUS_SUCCESS);
-     }
-   else
-     {
-       KAPC Apc;
-       KEVENT Event;
-       NTSTATUS AStatus;
-       CONTEXT KContext;
+  PETHREAD Thread;
+  NTSTATUS Status;
+  CONTEXT Context;
+  KAPC Apc;
+  KEVENT Event;
+  NTSTATUS AStatus;
+
+  Status = MmCopyFromCaller(&Context, UnsafeContext, sizeof(CONTEXT));
+  if (! NT_SUCCESS(Status))
+    {
+      return Status;
+    }
+  Status = ObReferenceObjectByHandle(ThreadHandle,
+                                     THREAD_GET_CONTEXT,
+                                     PsThreadType,
+                                     UserMode,
+                                     (PVOID*)&Thread,
+                                     NULL);
+  if (! NT_SUCCESS(Status))
+    {
+      return Status;
+    }
+  if (Thread == PsGetCurrentThread())
+    {
+      /*
+       * I don't know if trying to get your own context makes much
+       * sense but we can handle it more efficently.
+       */
        
-       KContext.ContextFlags = Context->ContextFlags;
-       KeInitializeEvent(&Event,
-                         NotificationEvent,
-                         FALSE);       
-       AStatus = STATUS_SUCCESS;
+      KeTrapFrameToContext(Thread->Tcb.TrapFrame, &Context);
+      Status = STATUS_SUCCESS;
+    }
+  else
+    {
+      KeInitializeEvent(&Event,
+                        NotificationEvent,
+                        FALSE);        
+      AStatus = STATUS_SUCCESS;
        
-       KeInitializeApc(&Apc,
-                       &Thread->Tcb,
-                       0,
-                       KeGetContextKernelRoutine,
-                       KeGetContextRundownRoutine,
-                       NULL,
-                       KernelMode,
-                       (PVOID)&KContext);
-       KeInsertQueueApc(&Apc,
-                        (PVOID)&Event,
-                        (PVOID)&AStatus,
-                        0);
-       Status = KeWaitForSingleObject(&Event,
-                                      0,
-                                      UserMode,
-                                      FALSE,
-                                      NULL);
-       if (!NT_SUCCESS(Status))
-         {
-            return(Status);
-         }
-       if (!NT_SUCCESS(AStatus))
-         {
-            return(AStatus);
-         }
-       memcpy(Context, &KContext, sizeof(CONTEXT));
-       ObDereferenceObject(Thread);
-       return(STATUS_SUCCESS);
-     }
+      KeInitializeApc(&Apc,
+                      &Thread->Tcb,
+                      OriginalApcEnvironment,
+                      KeGetContextKernelRoutine,
+                      KeGetSetContextRundownRoutine,
+                      NULL,
+                      KernelMode,
+                      (PVOID)&Context);
+      if (!KeInsertQueueApc(&Apc,
+                           (PVOID)&Event,
+                           (PVOID)&AStatus,
+                           IO_NO_INCREMENT))
+       {
+         Status = STATUS_THREAD_IS_TERMINATING;
+       }
+      else
+       {
+         Status = KeWaitForSingleObject(&Event,
+                                        0,
+                                        UserMode,
+                                        FALSE,
+                                        NULL);
+         if (NT_SUCCESS(Status) && !NT_SUCCESS(AStatus))
+           {
+             Status = AStatus;
+           }
+       }
+    }
+  if (NT_SUCCESS(Status))
+    {
+      Status = MmCopyToCaller(UnsafeContext, &Context, sizeof(Context));
+    }
+
+  ObDereferenceObject(Thread);
+  return Status;
+}
+
+VOID STDCALL
+KeSetContextKernelRoutine(PKAPC Apc,
+                         PKNORMAL_ROUTINE* NormalRoutine,
+                         PVOID* NormalContext,
+                         PVOID* SystemArgument1,
+                         PVOID* SystemArgument2)
+/*
+ * FUNCTION: This routine is called by an APC sent by NtSetContextThread to
+ * set the context of a thread from a buffer.
+ */
+{
+  PKEVENT Event;
+  PCONTEXT Context;
+  PNTSTATUS Status;
+   
+  Context = (PCONTEXT)(*NormalContext);
+  Event = (PKEVENT)(*SystemArgument1);
+  Status = (PNTSTATUS)(*SystemArgument2);
+   
+  KeContextToTrapFrame(Context, KeGetCurrentThread()->TrapFrame);
+   
+  *Status = STATUS_SUCCESS;
+  KeSetEvent(Event, IO_NO_INCREMENT, FALSE);
 }
 
 NTSTATUS STDCALL
 NtSetContextThread(IN HANDLE ThreadHandle,
-                  IN PCONTEXT Context)
+                  IN PCONTEXT UnsafeContext)
 {
-   UNIMPLEMENTED;
+  PETHREAD Thread;
+  NTSTATUS Status;
+  KAPC Apc;
+  KEVENT Event;
+  NTSTATUS AStatus;
+  CONTEXT Context;
+
+  Status = MmCopyFromCaller(&Context, UnsafeContext, sizeof(CONTEXT));
+  if (! NT_SUCCESS(Status))
+    {
+      return Status;
+    }
+  Status = ObReferenceObjectByHandle(ThreadHandle,
+                                     THREAD_SET_CONTEXT,
+                                     PsThreadType,
+                                     UserMode,
+                                     (PVOID*)&Thread,
+                                     NULL);
+  if (!NT_SUCCESS(Status))
+    {
+      return Status;
+    }
+
+  if (Thread == PsGetCurrentThread())
+    {
+      /*
+       * I don't know if trying to set your own context makes much
+       * sense but we can handle it more efficently.
+       */
+       
+      KeContextToTrapFrame(&Context, Thread->Tcb.TrapFrame);
+      Status = STATUS_SUCCESS;
+    }
+  else
+    {
+      KeInitializeEvent(&Event,
+                        NotificationEvent,
+                        FALSE);        
+      AStatus = STATUS_SUCCESS;
+       
+      KeInitializeApc(&Apc,
+                      &Thread->Tcb,
+                      OriginalApcEnvironment,
+                      KeSetContextKernelRoutine,
+                      KeGetSetContextRundownRoutine,
+                      NULL,
+                      KernelMode,
+                      (PVOID)&Context);
+      if (!KeInsertQueueApc(&Apc,
+                           (PVOID)&Event,
+                           (PVOID)&AStatus,
+                           IO_NO_INCREMENT))
+       {
+         Status = STATUS_THREAD_IS_TERMINATING;
+       }
+      else
+       {
+         Status = KeWaitForSingleObject(&Event,
+                                        0,
+                                        UserMode,
+                                        FALSE,
+                                     NULL);
+         if (NT_SUCCESS(Status) && !NT_SUCCESS(AStatus))
+           {
+             Status = AStatus;
+           }
+       }
+    }
+
+  ObDereferenceObject(Thread);
+  return Status;
 }
 
 /* EOF */