diff --git a/dll/win32/msafd/misc/dllmain.c b/dll/win32/msafd/misc/dllmain.c index 79623a142c0..06d0890f36b 100644 --- a/dll/win32/msafd/misc/dllmain.c +++ b/dll/win32/msafd/misc/dllmain.c @@ -418,7 +418,7 @@ ok: return (SOCKET)Sock; error: - ERR("Ending %x\n", Status); + TRACE("Ending %x\n", Status); if( SharedData ) { @@ -610,6 +610,20 @@ TranslateNtStatusError(NTSTATUS Status) } } +VOID +MsafdWaitForAlert( + _In_ HANDLE hObject) +{ + DWORD Result; + + for (;;) + { + Result = WaitForSingleObjectEx(hObject, INFINITE, TRUE); + if (Result != WAIT_IO_COMPLETION) + break; + } +} + /* * FUNCTION: Closes an open socket * ARGUMENTS: @@ -640,18 +654,8 @@ WSPCloseSocket(IN SOCKET Handle, return SOCKET_ERROR; } - /* Create the Wait Event */ - Status = NtCreateEvent(&SockEvent, - EVENT_ALL_ACCESS, - NULL, - SynchronizationEvent, - FALSE); - - if(!NT_SUCCESS(Status)) - { - ERR("NtCreateEvent failed: 0x%08x\n", Status); - return SOCKET_ERROR; - } + /* HACK: Allow APC to be processed */ + SleepEx(0, TRUE); if (Socket->HelperEvents & WSH_NOTIFY_CLOSE) { @@ -663,9 +667,8 @@ WSPCloseSocket(IN SOCKET Handle, if (Status) { - if (lpErrno) *lpErrno = Status; ERR("WSHNotify failed. Error 0x%#x\n", Status); - NtClose(SockEvent); + if (lpErrno) *lpErrno = Status; return SOCKET_ERROR; } } @@ -674,11 +677,22 @@ WSPCloseSocket(IN SOCKET Handle, if (Socket->SharedData->State == SocketClosed) { WARN("Socket is closing.\n"); - NtClose(SockEvent); if (lpErrno) *lpErrno = WSAENOTSOCK; return SOCKET_ERROR; } + /* Create the Wait Event */ + Status = NtCreateEvent(&SockEvent, + EVENT_ALL_ACCESS, + NULL, + SynchronizationEvent, + FALSE); + if (!NT_SUCCESS(Status)) + { + ERR("NtCreateEvent failed: 0x%08x\n", Status); + return SOCKET_ERROR; + } + /* Decrement reference count on SharedData */ References = InterlockedDecrement(&Socket->SharedData->RefCount); if (References) @@ -775,7 +789,7 @@ WSPCloseSocket(IN SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IoStatusBlock.Status; } } @@ -943,7 +957,7 @@ WSPBind(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1028,7 +1042,7 @@ WSPListen(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1304,7 +1318,7 @@ WSPSelect(IN int nfds, /* Wait for Completion */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1531,7 +1545,7 @@ WSPAccept( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1564,7 +1578,7 @@ WSPAccept( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1604,7 +1618,7 @@ WSPAccept( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1709,7 +1723,7 @@ WSPAccept( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1763,7 +1777,7 @@ WSPAccept( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -1825,6 +1839,60 @@ WSPAccept( return AcceptSocket; } +static +VOID +NTAPI +MsafdConnectAPC( + _In_ PVOID ApcContext, + _In_ PIO_STATUS_BLOCK IoStatusBlock, + _In_ ULONG Reserved) +{ + PMSAFD_CONNECT_APC_CONTEXT Context = ApcContext; + + TRACE("MsafdConnectAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information); + + PSOCKET_INFORMATION Socket = GetSocketStructure(Context->lpSocket); + if (!Socket) + { + // FIXME: Socket is closed before this APC could run + HeapFree(GlobalHeap, 0, ApcContext); + return; + } + + Socket->SharedData->SocketLastError = TranslateNtStatusError(IoStatusBlock->Status); + if (IoStatusBlock->Status == STATUS_SUCCESS) + { + Socket->SharedData->State = SocketConnected; + Socket->TdiConnectionHandle = (HANDLE)IoStatusBlock->Information; + Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds(); + } + + /* Re-enable Async Event */ + SockReenableAsyncSelectEvent(Socket, FD_WRITE); + + /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */ + SockReenableAsyncSelectEvent(Socket, FD_CONNECT); + + if (IoStatusBlock->Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT)) + { + Socket->HelperData->WSHNotify(Socket->HelperContext, + Socket->Handle, + Socket->TdiAddressHandle, + Socket->TdiConnectionHandle, + WSH_NOTIFY_CONNECT); + } + else if (IoStatusBlock->Status != STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT_ERROR)) + { + Socket->HelperData->WSHNotify(Socket->HelperContext, + Socket->Handle, + Socket->TdiAddressHandle, + Socket->TdiConnectionHandle, + WSH_NOTIFY_CONNECT_ERROR); + } + + HeapFree(GlobalHeap, 0, ApcContext); +} + int WSPAPI WSPConnect(SOCKET Handle, @@ -1836,19 +1904,21 @@ WSPConnect(SOCKET Handle, LPQOS lpGQOS, LPINT lpErrno) { - IO_STATUS_BLOCK IOSB; - PAFD_CONNECT_INFO ConnectInfo = NULL; - PSOCKET_INFORMATION Socket; - NTSTATUS Status; - INT Errno; - ULONG ConnectDataLength; - ULONG InConnectDataLength; - INT BindAddressLength; - PSOCKADDR BindAddress; - HANDLE SockEvent; - int SocketDataLength; + IO_STATUS_BLOCK DummyIOSB; + PIO_STATUS_BLOCK IOSB = &DummyIOSB; + PAFD_CONNECT_INFO ConnectInfo = NULL; + PSOCKET_INFORMATION Socket; + NTSTATUS Status; + INT Errno; + ULONG ConnectDataLength; + ULONG InConnectDataLength; + HANDLE SockEvent; + int SocketDataLength; + PMSAFD_CONNECT_APC_CONTEXT APCContext = NULL; + PIO_APC_ROUTINE APCFunction = NULL; + UCHAR Buffer[128]; - TRACE("Called (%lx) %lx:%d\n", Handle, ((const struct sockaddr_in *)SocketAddress)->sin_addr, ((const struct sockaddr_in *)SocketAddress)->sin_port); + TRACE("WSPConnect(%x)\n", Handle); /* Get the Socket Structure associate to this Socket*/ Socket = GetSocketStructure(Handle); @@ -1858,34 +1928,38 @@ WSPConnect(SOCKET Handle, return SOCKET_ERROR; } - Status = NtCreateEvent(&SockEvent, - EVENT_ALL_ACCESS, - NULL, - SynchronizationEvent, - FALSE); - - if (!NT_SUCCESS(Status)) - return SOCKET_ERROR; - /* Bind us First */ if (Socket->SharedData->State == SocketOpen) { + INT BindAddressLength; + PSOCKADDR BindAddress; + INT BindError; + /* Get the Wildcard Address */ BindAddressLength = Socket->HelperData->MaxWSAddressLength; BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength); if (!BindAddress) { - NtClose(SockEvent); return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL); } - Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext, - BindAddress, - &BindAddressLength); + Socket->HelperData->WSHGetWildcardSockaddr(Socket->HelperContext, + BindAddress, + &BindAddressLength); /* Bind it */ - if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR) + BindError = WSPBind(Handle, BindAddress, BindAddressLength, lpErrno); + HeapFree(GetProcessHeap(), 0, BindAddress); + if (BindError == SOCKET_ERROR) return SOCKET_ERROR; } + Status = NtCreateEvent(&SockEvent, + EVENT_ALL_ACCESS, + NULL, + SynchronizationEvent, + FALSE); + if (!NT_SUCCESS(Status)) + return SOCKET_ERROR; + /* Set the Connect Data */ if (lpCallerData != NULL) { @@ -1894,7 +1968,7 @@ WSPConnect(SOCKET Handle, SockEvent, NULL, NULL, - &IOSB, + IOSB, IOCTL_AFD_SET_CONNECT_DATA, lpCallerData->buf, ConnectDataLength, @@ -1903,25 +1977,25 @@ WSPConnect(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); - Status = IOSB.Status; + MsafdWaitForAlert(SockEvent); + Status = IOSB->Status; } if (Status != STATUS_SUCCESS) - goto notify; + goto Leave; } /* Calculate the size of SocketAddress->sa_data */ SocketDataLength = SocketAddressLength - FIELD_OFFSET(struct sockaddr, sa_data); - /* Allocate a connection info buffer with SocketDataLength bytes of payload */ - ConnectInfo = HeapAlloc(GetProcessHeap(), 0, - FIELD_OFFSET(AFD_CONNECT_INFO, - RemoteAddress.Address[0].Address[SocketDataLength])); - if (!ConnectInfo) + ConnectInfo = (PAFD_CONNECT_INFO)Buffer; + + int connectionInfoSize = FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address[SocketDataLength]); + + if (connectionInfoSize > 128) { - Status = STATUS_INSUFFICIENT_RESOURCES; - goto notify; + *lpErrno = WSAEFAULT; + return SOCKET_ERROR; } /* Set up Address in TDI Format */ @@ -1951,7 +2025,7 @@ WSPConnect(SOCKET Handle, SockEvent, NULL, NULL, - &IOSB, + IOSB, IOCTL_AFD_SET_CONNECT_DATA_SIZE, &InConnectDataLength, sizeof(InConnectDataLength), @@ -1961,12 +2035,12 @@ WSPConnect(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); - Status = IOSB.Status; + MsafdWaitForAlert(SockEvent); + Status = IOSB->Status; } if (Status != STATUS_SUCCESS) - goto notify; + goto Leave; } /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */ @@ -1974,46 +2048,76 @@ WSPConnect(SOCKET Handle, ConnectInfo->UseSAN = FALSE; ConnectInfo->Unknown = 0; - /* FIXME: Handle Async Connect */ + /* Verify if we should use APC */ if (Socket->SharedData->NonBlocking) { - ERR("Async Connect UNIMPLEMENTED!\n"); + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); + if (!APCContext) + { + ERR("Not enough memory for APC Context\n"); + Status = STATUS_INSUFFICIENT_RESOURCES; + goto Leave; + } + APCContext->lpSocket = Handle; + APCFunction = &MsafdConnectAPC; + + IOSB = &APCContext->IoStatusBlock; } + IOSB->Status = STATUS_PENDING; + IOSB->Information = 0; + /* Send IOCTL */ Status = NtDeviceIoControlFile((HANDLE)Handle, - SockEvent, - NULL, - NULL, - &IOSB, + APCFunction ? NULL : SockEvent, + APCFunction, + APCContext, + IOSB, IOCTL_AFD_CONNECT, ConnectInfo, - 0x22, + connectionInfoSize, NULL, 0); - /* Wait for return */ - if (Status == STATUS_PENDING) + if (Socket->SharedData->NonBlocking) { - WaitForSingleObject(SockEvent, INFINITE); - Status = IOSB.Status; + if (Status == STATUS_PENDING) + { + Status = STATUS_CANT_WAIT; // WSAEWOULDBLOCK + goto Leave; + } + else + { + /* HACK: Allow APC to be processed */ + SleepEx(0, TRUE); + } + } + else + { + /* Wait for completion if blocking */ + if (Status == STATUS_PENDING) + { + MsafdWaitForAlert(SockEvent); + Status = IOSB->Status; + } + + Socket->SharedData->SocketLastError = TranslateNtStatusError(Status); + if (Status != STATUS_SUCCESS) + goto Leave; + + Socket->SharedData->State = SocketConnected; + Socket->TdiConnectionHandle = (HANDLE)IOSB->Information; + Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds(); } - Socket->SharedData->SocketLastError = TranslateNtStatusError(Status); - if (Status != STATUS_SUCCESS) - goto notify; - - Socket->SharedData->State = SocketConnected; - Socket->TdiConnectionHandle = (HANDLE)IOSB.Information; - Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds(); - /* Get any pending connect data */ - if (lpCalleeData != NULL) + if (lpCalleeData != NULL && Status == STATUS_SUCCESS) { + IOSB = &DummyIOSB; Status = NtDeviceIoControlFile((HANDLE)Handle, SockEvent, NULL, NULL, - &IOSB, + IOSB, IOCTL_AFD_GET_CONNECT_DATA, NULL, 0, @@ -2022,15 +2126,15 @@ WSPConnect(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); - Status = IOSB.Status; + MsafdWaitForAlert(SockEvent); + Status = IOSB->Status; } } - TRACE("Ending %lx\n", IOSB.Status); +Leave: + TRACE("Ending %lx\n", Status); -notify: - if (ConnectInfo) HeapFree(GetProcessHeap(), 0, ConnectInfo); + NtClose(SockEvent); /* Re-enable Async Event */ SockReenableAsyncSelectEvent(Socket, FD_WRITE); @@ -2038,8 +2142,6 @@ notify: /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */ SockReenableAsyncSelectEvent(Socket, FD_CONNECT); - NtClose(SockEvent); - if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT)) { Errno = Socket->HelperData->WSHNotify(Socket->HelperContext, @@ -2138,7 +2240,7 @@ WSPShutdown(SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -2218,7 +2320,7 @@ WSPGetSockName(IN SOCKET Handle, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -2322,7 +2424,7 @@ WSPGetPeerName(IN SOCKET s, /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -3254,15 +3356,24 @@ WSPCleanup(OUT LPINT lpErrno) return 0; } +static VOID NTAPI -AfdInfoAPC(PVOID ApcContext, - PIO_STATUS_BLOCK IoStatusBlock, - ULONG Reserved) +MsafdInfoAPC( + _In_ PVOID ApcContext, + _In_ PIO_STATUS_BLOCK IoStatusBlock, + _In_ ULONG Reserved) { - PAFDAPCCONTEXT Context = ApcContext; + PMSAFD_INFO_APC_CONTEXT Context = ApcContext; + + TRACE("MsafdInfoAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information); + + if (Context->lpCompletionRoutine) + Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0); + + /* Free IOCTL buffer */ + HeapFree(GlobalHeap, 0, Context->lpInfoData); - Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0); HeapFree(GlobalHeap, 0, ApcContext); } @@ -3275,29 +3386,35 @@ GetSocketInformation(PSOCKET_INFORMATION Socket, LPWSAOVERLAPPED Overlapped OPTIONAL, LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL) { - PIO_STATUS_BLOCK IOSB; - IO_STATUS_BLOCK DummyIOSB; - AFD_INFO InfoData; - NTSTATUS Status; - PAFDAPCCONTEXT APCContext; - PIO_APC_ROUTINE APCFunction; - HANDLE Event = NULL; - HANDLE SockEvent; + PIO_STATUS_BLOCK IOSB; + IO_STATUS_BLOCK DummyIOSB; + PAFD_INFO InfoData; + NTSTATUS Status; + PMSAFD_INFO_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction; + HANDLE Event = NULL; + HANDLE SockEvent; + + InfoData = HeapAlloc(GlobalHeap, 0, sizeof(*InfoData)); + if (!InfoData) + return SOCKET_ERROR; Status = NtCreateEvent(&SockEvent, EVENT_ALL_ACCESS, NULL, SynchronizationEvent, FALSE); - - if( !NT_SUCCESS(Status) ) + if (!NT_SUCCESS(Status)) + { + HeapFree(GlobalHeap, 0, InfoData); return SOCKET_ERROR; + } /* Set Info Class */ - InfoData.InformationClass = AfdInformationClass; + InfoData->InformationClass = AfdInformationClass; /* Verify if we should use APC */ - if (Overlapped == NULL) + if (!Overlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -3311,31 +3428,26 @@ GetSocketInformation(PSOCKET_INFORMATION Socket, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - NtClose( SockEvent ); - return 0; + NtClose(SockEvent); + HeapFree(GlobalHeap, 0, InfoData); + return MsafdReturnWithErrno(STATUS_SUCCESS, NULL, 0, NULL); } - if (CompletionRoutine == NULL) + + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); + if (!APCContext) { - /* Using Overlapped Structure, but no Completition Routine, so no need for APC */ - APCContext = (PAFDAPCCONTEXT)Overlapped; - APCFunction = NULL; + ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); + HeapFree(GlobalHeap, 0, InfoData); + return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, NULL, 0, NULL); + } + APCContext->lpCompletionRoutine = CompletionRoutine; + APCContext->lpOverlapped = Overlapped; + APCContext->lpInfoData = InfoData; + APCFunction = &MsafdInfoAPC; + + if (!CompletionRoutine) Event = Overlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completition Routine, so use an APC */ - APCFunction = &AfdInfoAPC; // should be a private io completition function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); - if (!APCContext) - { - ERR("Not enough memory for APC Context\n"); - NtClose( SockEvent ); - return WSAEFAULT; - } - APCContext->lpCompletionRoutine = CompletionRoutine; - APCContext->lpOverlapped = Overlapped; - APCContext->lpSocket = Socket; - } IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal; } @@ -3349,47 +3461,45 @@ GetSocketInformation(PSOCKET_INFORMATION Socket, APCContext, IOSB, IOCTL_AFD_GET_INFO, - &InfoData, - sizeof(InfoData), - &InfoData, - sizeof(InfoData)); + InfoData, + sizeof(*InfoData), + InfoData, + sizeof(*InfoData)); - /* Wait for return */ - if (Status == STATUS_PENDING && Overlapped == NULL) + /* Wait for completion if not overlapped */ + if (!Overlapped && Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } - NtClose( SockEvent ); + TRACE("Status %lx\n", Status); - TRACE("Status %x Information %d\n", Status, IOSB->Information); - - if (Status == STATUS_PENDING) + if (Status == STATUS_SUCCESS) { - TRACE("Leaving (Pending)\n"); - return WSA_IO_PENDING; + /* Return Information */ + if (Ulong != NULL) + { + *Ulong = InfoData->Information.Ulong; + } + if (LargeInteger != NULL) + { + *LargeInteger = InfoData->Information.LargeInteger; + } + if (Boolean != NULL) + { + *Boolean = InfoData->Information.Boolean; + } } - if (Status != STATUS_SUCCESS) - return SOCKET_ERROR; - - /* Return Information */ - if (Ulong != NULL) + NtClose(SockEvent); + if (!APCFunction) { - *Ulong = InfoData.Information.Ulong; - } - if (LargeInteger != NULL) - { - *LargeInteger = InfoData.Information.LargeInteger; - } - if (Boolean != NULL) - { - *Boolean = InfoData.Information.Boolean; + /* When using APC, this will be freed by the APC function */ + HeapFree(GlobalHeap, 0, InfoData); } - return NO_ERROR; - + return MsafdReturnWithErrno(Status, NULL, 0, NULL); } @@ -3402,43 +3512,49 @@ SetSocketInformation(PSOCKET_INFORMATION Socket, LPWSAOVERLAPPED Overlapped OPTIONAL, LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL) { - PIO_STATUS_BLOCK IOSB; - IO_STATUS_BLOCK DummyIOSB; - AFD_INFO InfoData; - NTSTATUS Status; - PAFDAPCCONTEXT APCContext; - PIO_APC_ROUTINE APCFunction; - HANDLE Event = NULL; - HANDLE SockEvent; + PIO_STATUS_BLOCK IOSB; + IO_STATUS_BLOCK DummyIOSB; + PAFD_INFO InfoData; + NTSTATUS Status; + PMSAFD_INFO_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction; + HANDLE Event = NULL; + HANDLE SockEvent; + + InfoData = HeapAlloc(GlobalHeap, 0, sizeof(*InfoData)); + if (!InfoData) + return SOCKET_ERROR; Status = NtCreateEvent(&SockEvent, EVENT_ALL_ACCESS, NULL, SynchronizationEvent, FALSE); - - if( !NT_SUCCESS(Status) ) + if (!NT_SUCCESS(Status)) + { + HeapFree(GlobalHeap, 0, InfoData); return SOCKET_ERROR; + } /* Set Info Class */ - InfoData.InformationClass = AfdInformationClass; + InfoData->InformationClass = AfdInformationClass; /* Set Information */ if (Ulong != NULL) { - InfoData.Information.Ulong = *Ulong; + InfoData->Information.Ulong = *Ulong; } if (LargeInteger != NULL) { - InfoData.Information.LargeInteger = *LargeInteger; + InfoData->Information.LargeInteger = *LargeInteger; } if (Boolean != NULL) { - InfoData.Information.Boolean = *Boolean; + InfoData->Information.Boolean = *Boolean; } /* Verify if we should use APC */ - if (Overlapped == NULL) + if (!Overlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -3452,31 +3568,26 @@ SetSocketInformation(PSOCKET_INFORMATION Socket, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - NtClose( SockEvent ); - return 0; + NtClose(SockEvent); + HeapFree(GlobalHeap, 0, InfoData); + return MsafdReturnWithErrno(STATUS_SUCCESS, NULL, 0, NULL); } - if (CompletionRoutine == NULL) + + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); + if (!APCContext) { - /* Using Overlapped Structure, but no Completition Routine, so no need for APC */ - APCContext = (PAFDAPCCONTEXT)Overlapped; - APCFunction = NULL; + ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); + HeapFree(GlobalHeap, 0, InfoData); + return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, NULL, 0, NULL); + } + APCContext->lpCompletionRoutine = CompletionRoutine; + APCContext->lpOverlapped = Overlapped; + APCContext->lpInfoData = InfoData; + APCFunction = &MsafdInfoAPC; + + if (!CompletionRoutine) Event = Overlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completition Routine, so use an APC */ - APCFunction = &AfdInfoAPC; // should be a private io completition function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); - if (!APCContext) - { - ERR("Not enough memory for APC Context\n"); - NtClose( SockEvent ); - return WSAEFAULT; - } - APCContext->lpCompletionRoutine = CompletionRoutine; - APCContext->lpOverlapped = Overlapped; - APCContext->lpSocket = Socket; - } IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal; } @@ -3490,30 +3601,28 @@ SetSocketInformation(PSOCKET_INFORMATION Socket, APCContext, IOSB, IOCTL_AFD_SET_INFO, - &InfoData, - sizeof(InfoData), + InfoData, + sizeof(*InfoData), NULL, 0); - /* Wait for return */ - if (Status == STATUS_PENDING && Overlapped == NULL) + /* Wait for completion if not overlapped */ + if (!Overlapped && Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } - NtClose( SockEvent ); + TRACE("Status %lx\n", Status); - TRACE("Status %x Information %d\n", Status, IOSB->Information); - - if (Status == STATUS_PENDING) + NtClose(SockEvent); + if (!APCFunction) { - TRACE("Leaving (Pending)\n"); - return WSA_IO_PENDING; + /* When using APC, this will be freed by the APC function */ + HeapFree(GlobalHeap, 0, InfoData); } - return Status == STATUS_SUCCESS ? NO_ERROR : SOCKET_ERROR; - + return MsafdReturnWithErrno(Status, NULL, 0, NULL); } PSOCKET_INFORMATION @@ -3581,7 +3690,7 @@ int CreateContext(PSOCKET_INFORMATION Socket) /* Wait for Completion */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } diff --git a/dll/win32/msafd/misc/event.c b/dll/win32/msafd/misc/event.c index b3b0da2c743..06951cf85b5 100644 --- a/dll/win32/msafd/misc/event.c +++ b/dll/win32/msafd/misc/event.c @@ -106,7 +106,7 @@ WSPEventSelect( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } @@ -181,7 +181,7 @@ WSPEnumNetworkEvents( /* Wait for return */ if (Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); + MsafdWaitForAlert(SockEvent); Status = IOSB.Status; } diff --git a/dll/win32/msafd/misc/sndrcv.c b/dll/win32/msafd/misc/sndrcv.c index 06b0afb1d51..3fc154201d7 100644 --- a/dll/win32/msafd/misc/sndrcv.c +++ b/dll/win32/msafd/misc/sndrcv.c @@ -102,6 +102,7 @@ WSPGetOverlappedResult( OUT LPINT lpErrno) { PSOCKET_INFORMATION Socket; + PIO_STATUS_BLOCK IoStatusBlock; BOOL Ret; TRACE("Called (%x)\n", Handle); @@ -120,11 +121,22 @@ WSPGetOverlappedResult( *lpErrno = WSAEFAULT; return FALSE; } + + IoStatusBlock = (PIO_STATUS_BLOCK)lpOverlapped; Ret = GetOverlappedResult((HANDLE)Handle, lpOverlapped, lpdwBytes, fWait); + /* HACK: Allow APC to be processed */ + SleepEx(0, TRUE); + + if (!fWait && IoStatusBlock->Status == STATUS_PENDING) + { + if (lpErrno) + *lpErrno = WSA_IO_INCOMPLETE; + return FALSE; + } + if (Ret) { - *lpdwFlags = 0; /* Re-enable Async Event */ SockReenableAsyncSelectEvent(Socket, FD_OOB); @@ -132,23 +144,75 @@ WSPGetOverlappedResult( SockReenableAsyncSelectEvent(Socket, FD_READ); } + switch (IoStatusBlock->Status) + { + case STATUS_RECEIVE_EXPEDITED: + *lpdwFlags = MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL_EXPEDITED: + *lpdwFlags = MSG_PARTIAL | MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL: + *lpdwFlags = MSG_PARTIAL; + break; + default: + *lpdwFlags = 0; + if (lpErrno) + *lpErrno = TranslateNtStatusError(IoStatusBlock->Status); + break; + } + return Ret; } +static VOID NTAPI -AfdAPC(PVOID ApcContext, - PIO_STATUS_BLOCK IoStatusBlock, - ULONG Reserved) +MsafdRecvAPC( + _In_ PVOID ApcContext, + _In_ PIO_STATUS_BLOCK IoStatusBlock, + _In_ ULONG Reserved) { - PAFDAPCCONTEXT Context = ApcContext; + PMSAFD_RECV_APC_CONTEXT Context = ApcContext; + + TRACE("MsafdRecvAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information); + + /* Re-enable Async Event */ + if (IoStatusBlock->Status == STATUS_RECEIVE_EXPEDITED || + IoStatusBlock->Status == STATUS_RECEIVE_PARTIAL_EXPEDITED) + { + SockReenableAsyncSelectEvent(Context->lpSocket, FD_OOB); + } + else + { + SockReenableAsyncSelectEvent(Context->lpSocket, FD_READ); + } + + if (Context->lpCompletionRoutine) + Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, (LPWSAOVERLAPPED)IoStatusBlock, 0); + + HeapFree(GlobalHeap, 0, ApcContext); +} + +static +VOID +NTAPI +MsafdSendAPC( + _In_ PVOID ApcContext, + _In_ PIO_STATUS_BLOCK IoStatusBlock, + _In_ ULONG Reserved) +{ + PMSAFD_SEND_APC_CONTEXT Context = ApcContext; + + TRACE("MsafdSendAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information); /* Re-enable Async Event */ - SockReenableAsyncSelectEvent(Context->lpSocket, FD_OOB); - SockReenableAsyncSelectEvent(Context->lpSocket, FD_READ); SockReenableAsyncSelectEvent(Context->lpSocket, FD_WRITE); - Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0); + if (Context->lpCompletionRoutine) + Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, (LPWSAOVERLAPPED)IoStatusBlock, 0); + + /* Free IOCTL buffers */ HeapFree(GlobalHeap, 0, ApcContext); } @@ -168,13 +232,14 @@ WSPRecv(SOCKET Handle, IO_STATUS_BLOCK DummyIOSB; AFD_RECV_INFO RecvInfo; NTSTATUS Status; - PVOID APCContext; - PIO_APC_ROUTINE APCFunction; + PMSAFD_RECV_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction = NULL; HANDLE Event = NULL; HANDLE SockEvent; PSOCKET_INFORMATION Socket; + DWORD NumberOfBytesRead; - TRACE("Called (%x)\n", Handle); + TRACE("WSPRecv(%x)\n", Handle); /* Get the Socket Structure associate to this Socket*/ Socket = GetSocketStructure(Handle); @@ -197,11 +262,15 @@ WSPRecv(SOCKET Handle, return SOCKET_ERROR; } + memset(&RecvInfo, 0, sizeof(RecvInfo)); + Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS, NULL, SynchronizationEvent, FALSE ); - if( !NT_SUCCESS(Status) ) - return -1; + if (!NT_SUCCESS(Status)) + { + return SOCKET_ERROR; + } /* Set up the Receive Structure */ RecvInfo.BufferArray = (PAFD_WSABUF)lpBuffers; @@ -210,7 +279,7 @@ WSPRecv(SOCKET Handle, RecvInfo.AfdFlags = Socket->SharedData->NonBlocking ? AFD_IMMEDIATE : 0; /* Set the TDI Flags */ - if (*ReceiveFlags == 0) + if (!ReceiveFlags || *ReceiveFlags == 0) { RecvInfo.TdiFlags |= TDI_RECEIVE_NORMAL; } @@ -233,8 +302,7 @@ WSPRecv(SOCKET Handle, } /* Verify if we should use APC */ - - if (lpOverlapped == NULL) + if (!lpOverlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -248,35 +316,48 @@ WSPRecv(SOCKET Handle, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead); + NtClose(SockEvent); + return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesRead); } - if (lpCompletionRoutine == NULL) + + IOSB = (PIO_STATUS_BLOCK)lpOverlapped; + if (lpCompletionRoutine) { - /* Using Overlapped Structure, but no Completion Routine, so no need for APC */ - APCContext = lpOverlapped; - APCFunction = NULL; - Event = lpOverlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completion Routine, so use an APC */ - APCFunction = &AfdAPC; // should be a private io completion function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); + RecvInfo.AfdFlags |= AFD_SKIP_FIO; + APCFunction = &MsafdRecvAPC; + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); if (!APCContext) { ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead); } - ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine; - ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped; - ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket; - RecvInfo.AfdFlags |= AFD_SKIP_FIO; + APCContext->lpCompletionRoutine = lpCompletionRoutine; + APCContext->lpOverlapped = lpOverlapped; + APCContext->lpSocket = Socket; + Event = NULL; + } + else + { + // .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure + // Do not use APC when there is no completion context + Event = lpOverlapped->hEvent; + + APCFunction = NULL; + if (Event) + { + APCContext = NULL; + } + else + { + APCContext = (PMSAFD_RECV_APC_CONTEXT)lpOverlapped; + } } - IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal; RecvInfo.AfdFlags |= AFD_OVERLAPPED; } + IOSB->Information = 0; IOSB->Status = STATUS_PENDING; /* Send IOCTL */ @@ -291,51 +372,47 @@ WSPRecv(SOCKET Handle, NULL, 0); - /* Non-blocking sockets must wait until data is available */ - if (Status == STATUS_PENDING && Socket->SharedData->NonBlocking) + /* Wait for completion if not overlapped */ + if (!lpOverlapped && Status == STATUS_PENDING) { - if (lpErrno) *lpErrno = WSAEWOULDBLOCK; - return SOCKET_ERROR; - } - - /* Wait for completion of not overlapped */ - if (Status == STATUS_PENDING && lpOverlapped == NULL) - { - /* It's up to the protocol to time out recv. We must wait - * until the protocol decides it's had enough. - */ - WaitForSingleObject(SockEvent, INFINITE); + /* It's up to the protocol to time out recv. We must wait + * until the protocol decides it's had enough. */ + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } - NtClose( SockEvent ); - - TRACE("Status %x Information %d\n", Status, IOSB->Information); + NtClose(SockEvent); if (Status == STATUS_PENDING) { TRACE("Leaving (Pending)\n"); - return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead); + return MsafdReturnWithErrno(Status, lpErrno, 0, NULL); } - /* Return the Flags */ - *ReceiveFlags = 0; + NumberOfBytesRead = (DWORD)IOSB->Information; - switch (Status) + /* Return the Flags */ + if (ReceiveFlags) { - case STATUS_RECEIVE_EXPEDITED: - *ReceiveFlags = MSG_OOB; - break; - case STATUS_RECEIVE_PARTIAL_EXPEDITED: - *ReceiveFlags = MSG_PARTIAL | MSG_OOB; - break; - case STATUS_RECEIVE_PARTIAL: - *ReceiveFlags = MSG_PARTIAL; - break; + switch (Status) + { + case STATUS_RECEIVE_EXPEDITED: + *ReceiveFlags = MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL_EXPEDITED: + *ReceiveFlags = MSG_PARTIAL | MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL: + *ReceiveFlags = MSG_PARTIAL; + break; + default: + *ReceiveFlags = 0; + break; + } } /* Re-enable Async Event */ - if (*ReceiveFlags & MSG_OOB) + if (ReceiveFlags && (*ReceiveFlags & MSG_OOB)) { SockReenableAsyncSelectEvent(Socket, FD_OOB); } @@ -344,13 +421,8 @@ WSPRecv(SOCKET Handle, SockReenableAsyncSelectEvent(Socket, FD_READ); } - if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine) - { - lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags); - HeapFree(GlobalHeap, 0, (PVOID)APCContext); - } - - return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead ); + TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesRead); + return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesRead, lpNumberOfBytesRead); } int @@ -371,11 +443,14 @@ WSPRecvFrom(SOCKET Handle, IO_STATUS_BLOCK DummyIOSB; AFD_RECV_INFO_UDP RecvInfo; NTSTATUS Status; - PVOID APCContext; - PVOID APCFunction; + PMSAFD_RECV_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction = NULL; HANDLE Event = NULL; HANDLE SockEvent; PSOCKET_INFORMATION Socket; + DWORD NumberOfBytesRead; + + TRACE("WSPRecvFrom(%x)\n", Handle); /* Get the Socket Structure associate to this Socket*/ Socket = GetSocketStructure(Handle); @@ -422,12 +497,15 @@ WSPRecvFrom(SOCKET Handle, if (WSPBind(Handle, SocketAddress, *SocketAddressLength, lpErrno) == SOCKET_ERROR) return SOCKET_ERROR; } + memset(&RecvInfo, 0, sizeof(RecvInfo)); Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS, NULL, SynchronizationEvent, FALSE ); if( !NT_SUCCESS(Status) ) - return -1; + { + return SOCKET_ERROR; + } /* Set up the Receive Structure */ RecvInfo.BufferArray = (PAFD_WSABUF)lpBuffers; @@ -438,7 +516,7 @@ WSPRecvFrom(SOCKET Handle, RecvInfo.Address = SocketAddress; /* Set the TDI Flags */ - if (*ReceiveFlags == 0) + if (!ReceiveFlags || *ReceiveFlags == 0) { RecvInfo.TdiFlags |= TDI_RECEIVE_NORMAL; } @@ -461,8 +539,7 @@ WSPRecvFrom(SOCKET Handle, } /* Verify if we should use APC */ - - if (lpOverlapped == NULL) + if (!lpOverlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -476,35 +553,48 @@ WSPRecvFrom(SOCKET Handle, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead); + NtClose(SockEvent); + return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesRead); } - if (lpCompletionRoutine == NULL) + + IOSB = (PIO_STATUS_BLOCK)lpOverlapped; + if (lpCompletionRoutine) { - /* Using Overlapped Structure, but no Completion Routine, so no need for APC */ - APCContext = lpOverlapped; - APCFunction = NULL; - Event = lpOverlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completion Routine, so use an APC */ - APCFunction = &AfdAPC; // should be a private io completion function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); + RecvInfo.AfdFlags |= AFD_SKIP_FIO; + APCFunction = &MsafdRecvAPC; + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); if (!APCContext) { ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead); } - ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine; - ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped; - ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket; - RecvInfo.AfdFlags |= AFD_SKIP_FIO; + APCContext->lpCompletionRoutine = lpCompletionRoutine; + APCContext->lpOverlapped = lpOverlapped; + APCContext->lpSocket = Socket; + Event = NULL; + } + else + { + // .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure + // Do not use APC when there is no completion context + Event = lpOverlapped->hEvent; + + APCFunction = NULL; + if (Event) + { + APCContext = NULL; + } + else + { + APCContext = (PMSAFD_RECV_APC_CONTEXT)lpOverlapped; + } } - IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal; RecvInfo.AfdFlags |= AFD_OVERLAPPED; } + IOSB->Information = 0; IOSB->Status = STATUS_PENDING; /* Send IOCTL */ @@ -519,39 +609,47 @@ WSPRecvFrom(SOCKET Handle, NULL, 0); - /* Wait for completion of not overlapped */ - if (Status == STATUS_PENDING && lpOverlapped == NULL) + /* Wait for completion if not overlapped */ + if (!lpOverlapped && Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); // BUGBUG, shouldn wait infinitely for receive... + /* FIXME: Shouldn't wait infinitely for receive... */ + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } - NtClose( SockEvent ); + /* Return the Flags */ + if (ReceiveFlags) + { + switch (Status) + { + case STATUS_RECEIVE_EXPEDITED: + *ReceiveFlags = MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL_EXPEDITED: + *ReceiveFlags = MSG_PARTIAL | MSG_OOB; + break; + case STATUS_RECEIVE_PARTIAL: + *ReceiveFlags = MSG_PARTIAL; + break; + default: + *ReceiveFlags = 0; + break; + } + } + + NtClose(SockEvent); if (Status == STATUS_PENDING) { TRACE("Leaving (Pending)\n"); - return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead); + *lpErrno = WSA_IO_PENDING; + return MsafdReturnWithErrno(Status, lpErrno, 0, NULL); } - /* Return the Flags */ - *ReceiveFlags = 0; - - switch (Status) - { - case STATUS_RECEIVE_EXPEDITED: - *ReceiveFlags = MSG_OOB; - break; - case STATUS_RECEIVE_PARTIAL_EXPEDITED: - *ReceiveFlags = MSG_PARTIAL | MSG_OOB; - break; - case STATUS_RECEIVE_PARTIAL: - *ReceiveFlags = MSG_PARTIAL; - break; - } + NumberOfBytesRead = (DWORD)IOSB->Information; /* Re-enable Async Event */ - if (*ReceiveFlags & MSG_OOB) + if (ReceiveFlags && (*ReceiveFlags & MSG_OOB)) { SockReenableAsyncSelectEvent(Socket, FD_OOB); } @@ -560,13 +658,8 @@ WSPRecvFrom(SOCKET Handle, SockReenableAsyncSelectEvent(Socket, FD_READ); } - if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine) - { - lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags); - HeapFree(GlobalHeap, 0, (PVOID)APCContext); - } - - return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead ); + TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesRead); + return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesRead, lpNumberOfBytesRead); } @@ -584,13 +677,16 @@ WSPSend(SOCKET Handle, { PIO_STATUS_BLOCK IOSB; IO_STATUS_BLOCK DummyIOSB; - AFD_SEND_INFO SendInfo; + AFD_SEND_INFO SendInfo; NTSTATUS Status; - PVOID APCContext; - PVOID APCFunction; + PMSAFD_SEND_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction = NULL; HANDLE Event = NULL; HANDLE SockEvent; PSOCKET_INFORMATION Socket; + DWORD NumberOfBytesSent; + + TRACE("WSPSend(%x)\n", Handle); /* Get the Socket Structure associate to this Socket*/ Socket = GetSocketStructure(Handle); @@ -607,13 +703,15 @@ WSPSend(SOCKET Handle, return SOCKET_ERROR; } + memset(&SendInfo, 0, sizeof(SendInfo)); + Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS, NULL, SynchronizationEvent, FALSE ); if( !NT_SUCCESS(Status) ) - return -1; - - TRACE("Called\n"); + { + return SOCKET_ERROR; + } /* Set up the Send Structure */ SendInfo.BufferArray = (PAFD_WSABUF)lpBuffers; @@ -635,7 +733,7 @@ WSPSend(SOCKET Handle, } /* Verify if we should use APC */ - if (lpOverlapped == NULL) + if (!lpOverlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -649,35 +747,58 @@ WSPSend(SOCKET Handle, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent); + NtClose(SockEvent); + return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesSent); } - if (lpCompletionRoutine == NULL) + + IOSB = (PIO_STATUS_BLOCK)lpOverlapped; + if (lpCompletionRoutine) { - /* Using Overlapped Structure, but no Completion Routine, so no need for APC */ - APCContext = lpOverlapped; - APCFunction = NULL; - Event = lpOverlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completion Routine, so use an APC */ - APCFunction = &AfdAPC; // should be a private io completion function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); + SendInfo.AfdFlags |= AFD_SKIP_FIO; + APCFunction = &MsafdSendAPC; + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); if (!APCContext) { ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent); } - ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine; - ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped; - ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket; - SendInfo.AfdFlags |= AFD_SKIP_FIO; + APCContext->lpCompletionRoutine = lpCompletionRoutine; + APCContext->lpOverlapped = lpOverlapped; + APCContext->lpSocket = Socket; + APCContext->lpRemoteAddress = NULL; + Event = NULL; + } + else + { + // .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure + // Do not use APC when there is no completion context + Event = lpOverlapped->hEvent; + + APCFunction = NULL; + if (Event) + { + APCContext = NULL; + } + else + { + APCContext = (PMSAFD_SEND_APC_CONTEXT)lpOverlapped; + } + } + + Event = lpOverlapped->hEvent; + // APCContext is unsupported when using event + if (Event) + { + if (lpCompletionRoutine) + HeapFree(GetProcessHeap(), 0, APCContext); + APCContext = NULL; } - IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal; SendInfo.AfdFlags |= AFD_OVERLAPPED; } + IOSB->Information = 0; IOSB->Status = STATUS_PENDING; /* Send IOCTL */ @@ -692,33 +813,29 @@ WSPSend(SOCKET Handle, NULL, 0); - /* Wait for completion of not overlapped */ - if (Status == STATUS_PENDING && lpOverlapped == NULL) + /* Wait for completion if not overlapped */ + if (!lpOverlapped && Status == STATUS_PENDING) { - WaitForSingleObject(SockEvent, INFINITE); // BUGBUG, shouldn wait infinitely for send... + /* FIXME: Shouldn't wait infinitely for send... */ + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } - NtClose( SockEvent ); + NtClose(SockEvent); if (Status == STATUS_PENDING) { TRACE("Leaving (Pending)\n"); - return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent); + return MsafdReturnWithErrno(Status, lpErrno, 0, NULL); } + NumberOfBytesSent = (DWORD)IOSB->Information; + /* Re-enable Async Event */ SockReenableAsyncSelectEvent(Socket, FD_WRITE); - TRACE("Leaving (Success, %d)\n", IOSB->Information); - - if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine) - { - lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0); - HeapFree(GlobalHeap, 0, (PVOID)APCContext); - } - - return MsafdReturnWithErrno( Status, lpErrno, IOSB->Information, lpNumberOfBytesSent ); + TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesSent); + return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesSent, lpNumberOfBytesSent); } int @@ -739,14 +856,16 @@ WSPSendTo(SOCKET Handle, IO_STATUS_BLOCK DummyIOSB; AFD_SEND_INFO_UDP SendInfo; NTSTATUS Status; - PVOID APCContext; - PVOID APCFunction; + PMSAFD_SEND_APC_CONTEXT APCContext; + PIO_APC_ROUTINE APCFunction = NULL; HANDLE Event = NULL; PTRANSPORT_ADDRESS RemoteAddress; - PSOCKADDR BindAddress = NULL; - INT BindAddressLength; HANDLE SockEvent; PSOCKET_INFORMATION Socket; + DWORD NumberOfBytesSent; + UCHAR Buffer[128]; + + TRACE("WSPSendTo(%x)\n", Handle); /* Get the Socket Structure associate to this Socket */ Socket = GetSocketStructure(Handle); @@ -780,6 +899,10 @@ WSPSendTo(SOCKET Handle, /* Bind us First */ if (Socket->SharedData->State == SocketOpen) { + INT BindAddressLength; + PSOCKADDR BindAddress; + INT BindError; + /* Get the Wildcard Address */ BindAddressLength = Socket->HelperData->MaxWSAddressLength; BindAddress = HeapAlloc(GlobalHeap, 0, BindAddressLength); @@ -793,35 +916,29 @@ WSPSendTo(SOCKET Handle, BindAddress, &BindAddressLength); /* Bind it */ - if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR) + BindError = WSPBind(Handle, BindAddress, BindAddressLength, lpErrno); + HeapFree(GlobalHeap, 0, BindAddress); + if (BindError == SOCKET_ERROR) return SOCKET_ERROR; } - RemoteAddress = HeapAlloc(GlobalHeap, 0, 0x6 + SocketAddressLength); - if (!RemoteAddress) - { - if (BindAddress != NULL) - { - HeapFree(GlobalHeap, 0, BindAddress); - } - return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL); - } - Status = NtCreateEvent(&SockEvent, - EVENT_ALL_ACCESS, - NULL, SynchronizationEvent, FALSE); + EVENT_ALL_ACCESS, + NULL, SynchronizationEvent, FALSE); if (!NT_SUCCESS(Status)) { - HeapFree(GlobalHeap, 0, RemoteAddress); - if (BindAddress != NULL) - { - HeapFree(GlobalHeap, 0, BindAddress); - } return SOCKET_ERROR; } - /* Set up Address in TDI Format */ + RemoteAddress = (PTRANSPORT_ADDRESS)Buffer; + + if (SocketAddressLength > 128 - sizeof(TRANSPORT_ADDRESS)) { + *lpErrno = WSAEFAULT; + return SOCKET_ERROR; + } + + /* Fill TDI address */ RemoteAddress->TAAddressCount = 1; RemoteAddress->Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family); RtlCopyMemory(&RemoteAddress->Address[0].AddressType, SocketAddress, SocketAddressLength); @@ -834,7 +951,7 @@ WSPSendTo(SOCKET Handle, SendInfo.TdiConnection.RemoteAddressLength = Socket->HelperData->MaxTDIAddressLength; /* Verify if we should use APC */ - if (lpOverlapped == NULL) + if (!lpOverlapped) { /* Not using Overlapped structure, so use normal blocking on event */ APCContext = NULL; @@ -848,35 +965,51 @@ WSPSendTo(SOCKET Handle, if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0) { TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n"); - return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent); + NtClose(SockEvent); + return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesSent); } - if (lpCompletionRoutine == NULL) + + IOSB = (PIO_STATUS_BLOCK)lpOverlapped; + if (lpCompletionRoutine) { - /* Using Overlapped Structure, but no Completion Routine, so no need for APC */ - APCContext = lpOverlapped; - APCFunction = NULL; - Event = lpOverlapped->hEvent; - } - else - { - /* Using Overlapped Structure and a Completion Routine, so use an APC */ - APCFunction = &AfdAPC; // should be a private io completion function inside us - APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT)); + SendInfo.AfdFlags |= AFD_SKIP_FIO; + APCFunction = &MsafdSendAPC; + APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext)); if (!APCContext) { ERR("Not enough memory for APC Context\n"); + NtClose(SockEvent); return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent); } - ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine; - ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped; - ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket; - SendInfo.AfdFlags |= AFD_SKIP_FIO; + APCContext->lpCompletionRoutine = lpCompletionRoutine; + APCContext->lpOverlapped = lpOverlapped; + APCContext->lpSocket = Socket; + APCContext->lpRemoteAddress = RemoteAddress; + Event = NULL; + } + else + { + // .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure + // Do not use APC when there is no completion context + Event = lpOverlapped->hEvent; + + APCFunction = NULL; + if (Event) + { + APCContext = NULL; + } + else + { + APCContext = (PMSAFD_SEND_APC_CONTEXT)lpOverlapped; + } } - IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal; SendInfo.AfdFlags |= AFD_OVERLAPPED; } + IOSB->Information = 0; + IOSB->Status = STATUS_PENDING; + /* Send IOCTL */ Status = NtDeviceIoControlFile((HANDLE)Handle, Event, @@ -889,36 +1022,29 @@ WSPSendTo(SOCKET Handle, NULL, 0); - /* Wait for completion of not overlapped */ - if (Status == STATUS_PENDING && lpOverlapped == NULL) + /* Wait for completion if not overlapped */ + if (!lpOverlapped && Status == STATUS_PENDING) { - /* BUGBUG, shouldn't wait infinitely for send... */ - WaitForSingleObject(SockEvent, INFINITE); + /* FIXME: Shouldn't wait infinitely for send... */ + MsafdWaitForAlert(SockEvent); Status = IOSB->Status; } NtClose(SockEvent); - HeapFree(GlobalHeap, 0, RemoteAddress); - if (BindAddress != NULL) - { - HeapFree(GlobalHeap, 0, BindAddress); - } if (Status == STATUS_PENDING) { TRACE("Leaving (Pending)\n"); - return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent); + return MsafdReturnWithErrno(Status, lpErrno, 0, NULL); } + NumberOfBytesSent = (DWORD)IOSB->Information; + + /* Re-enable Async Event */ SockReenableAsyncSelectEvent(Socket, FD_WRITE); - if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine) - { - lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0); - HeapFree(GlobalHeap, 0, (PVOID)APCContext); - } - - return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent); + TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesSent); + return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesSent, lpNumberOfBytesSent); } INT diff --git a/dll/win32/msafd/msafd.h b/dll/win32/msafd/msafd.h index 87b0125cc4c..685846e5c59 100644 --- a/dll/win32/msafd/msafd.h +++ b/dll/win32/msafd/msafd.h @@ -133,12 +133,34 @@ typedef struct _ASYNC_DATA { AFD_POLL_INFO AsyncSelectInfo; } ASYNC_DATA, *PASYNC_DATA; -typedef struct _AFDAPCCONTEXT +typedef struct _MSAFD_INFO_APC_CONTEXT +{ + LPWSAOVERLAPPED lpOverlapped; + LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine; + PAFD_INFO lpInfoData; +} MSAFD_INFO_APC_CONTEXT, *PMSAFD_INFO_APC_CONTEXT; + +typedef struct _MSAFD_SEND_APC_CONTEXT { LPWSAOVERLAPPED lpOverlapped; LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine; PSOCKET_INFORMATION lpSocket; -} AFDAPCCONTEXT, *PAFDAPCCONTEXT; + PTRANSPORT_ADDRESS lpRemoteAddress; +} MSAFD_SEND_APC_CONTEXT, *PMSAFD_SEND_APC_CONTEXT; + +typedef struct _MSAFD_RECV_APC_CONTEXT +{ + LPWSAOVERLAPPED lpOverlapped; + LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine; + PSOCKET_INFORMATION lpSocket; +} MSAFD_RECV_APC_CONTEXT, *PMSAFD_RECV_APC_CONTEXT; + +typedef struct _MSAFD_CONNECT_APC_CONTEXT +{ + SOCKET lpSocket; + IO_STATUS_BLOCK IoStatusBlock; +} MSAFD_CONNECT_APC_CONTEXT, *PMSAFD_CONNECT_APC_CONTEXT; + _Must_inspect_result_ SOCKET @@ -543,10 +565,11 @@ typedef VOID (*PASYNC_COMPLETION_ROUTINE)(PVOID Context, PIO_STATUS_BLOCK IoStat FORCEINLINE DWORD -MsafdReturnWithErrno(NTSTATUS Status, - LPINT Errno, - DWORD Received, - LPDWORD ReturnedBytes) +MsafdReturnWithErrno( + _In_ NTSTATUS Status, + _Out_opt_ LPINT Errno, + _In_ DWORD Received, + _Out_opt_ LPDWORD ReturnedBytes) { if (Errno) { @@ -559,8 +582,6 @@ MsafdReturnWithErrno(NTSTATUS Status, } else { - DbgPrint("%s: Received invalid lpErrno pointer!\n", __FUNCTION__); - if (ReturnedBytes) *ReturnedBytes = (Status == STATUS_SUCCESS) ? Received : 0; @@ -568,4 +589,8 @@ MsafdReturnWithErrno(NTSTATUS Status, } } +VOID +MsafdWaitForAlert( + _In_ HANDLE hObject); + #endif /* __MSAFD_H */ diff --git a/drivers/network/afd/afd/write.c b/drivers/network/afd/afd/write.c index 29ad9501b73..24a75a99a44 100644 --- a/drivers/network/afd/afd/write.c +++ b/drivers/network/afd/afd/write.c @@ -379,6 +379,10 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp, Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND); if (Status == STATUS_PENDING) { + if (SendReq->BufferCount > 1) + { + AFD_DbgPrint(MIN_TRACE,("WARN: More than one buffer %ld\n", SendReq->BufferCount)); + } Status = TdiSendDatagram(&FCB->SendIrp.InFlightRequest, FCB->AddressFile.Object, SendReq->BufferArray[0].buf, @@ -390,7 +394,7 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp, { NT_VERIFY(RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]) == &Irp->Tail.Overlay.ListEntry); Irp->IoStatus.Status = Status; - Irp->IoStatus.Information = 0; + Irp->IoStatus.Information = ((Status == STATUS_SUCCESS) ? SendReq->BufferArray[0].len : 0); (void)IoSetCancelRoutine(Irp, NULL); UnlockBuffers(SendReq->BufferArray, SendReq->BufferCount, FALSE); UnlockRequest(Irp, IoGetCurrentIrpStackLocation(Irp)); @@ -645,6 +649,10 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp, Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND); if (Status == STATUS_PENDING) { + if (SendReq->BufferCount > 1) + { + AFD_DbgPrint(MIN_TRACE,("WARN: More than one buffer %ld\n", SendReq->BufferCount)); + } Status = TdiSendDatagram(&FCB->SendIrp.InFlightRequest, FCB->AddressFile.Object, SendReq->BufferArray[0].buf, @@ -656,7 +664,7 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp, { NT_VERIFY(RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]) == &Irp->Tail.Overlay.ListEntry); Irp->IoStatus.Status = Status; - Irp->IoStatus.Information = 0; + Irp->IoStatus.Information = ((Status == STATUS_SUCCESS) ? SendReq->BufferArray[0].len : 0); (void)IoSetCancelRoutine(Irp, NULL); UnlockBuffers(SendReq->BufferArray, SendReq->BufferCount, FALSE); UnlockRequest(Irp, IoGetCurrentIrpStackLocation(Irp)); diff --git a/drivers/network/tdihelpers/tdi.c b/drivers/network/tdihelpers/tdi.c index e4019b01ab4..fdb83c3f4bd 100644 --- a/drivers/network/tdihelpers/tdi.c +++ b/drivers/network/tdihelpers/tdi.c @@ -989,11 +989,9 @@ NTSTATUS TdiSend( Flags, BufferLength); - TdiCall(*Irp, DeviceObject, NULL, NULL); /* Does not block... The MDL is deleted in the receive completion routine. */ - - return STATUS_PENDING; + return TdiCall(*Irp, DeviceObject, NULL, NULL); } NTSTATUS TdiReceive( @@ -1256,11 +1254,9 @@ NTSTATUS TdiSendDatagram( BufferLength, Addr); - TdiCall(*Irp, DeviceObject, NULL, NULL); /* Does not block... The MDL is deleted in the send completion routine. */ - - return STATUS_PENDING; + return TdiCall(*Irp, DeviceObject, NULL, NULL); } NTSTATUS TdiDisconnect(