diff --git a/dll/win32/userenv/profile.c b/dll/win32/userenv/profile.c index 6b4635ba062..1c18a67ab67 100644 --- a/dll/win32/userenv/profile.c +++ b/dll/win32/userenv/profile.c @@ -197,6 +197,157 @@ CreateProfileMutex( } +static +DWORD +IncrementRefCount( + PWSTR pszSidString, + PDWORD pdwRefCount) +{ + HKEY hProfilesKey = NULL, hProfileKey = NULL; + DWORD dwRefCount = 0, dwLength, dwType; + DWORD dwError; + + DPRINT1("IncrementRefCount(%S %p)\n", + pszSidString, pdwRefCount); + + dwError = RegOpenKeyExW(HKEY_LOCAL_MACHINE, + L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList", + 0, + KEY_QUERY_VALUE, + &hProfilesKey); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + dwError = RegOpenKeyExW(hProfilesKey, + pszSidString, + 0, + KEY_QUERY_VALUE | KEY_SET_VALUE, + &hProfileKey); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + /* Get the reference counter */ + dwLength = sizeof(dwRefCount); + RegQueryValueExW(hProfileKey, + L"RefCount", + NULL, + &dwType, + (PBYTE)&dwRefCount, + &dwLength); + + dwRefCount++; + + dwLength = sizeof(dwRefCount); + dwError = RegSetValueExW(hProfileKey, + L"RefCount", + 0, + REG_DWORD, + (PBYTE)&dwRefCount, + dwLength); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + if (pdwRefCount != NULL) + *pdwRefCount = dwRefCount; + +done: + if (hProfileKey != NULL) + RegCloseKey(hProfileKey); + + if (hProfilesKey != NULL) + RegCloseKey(hProfilesKey); + + return dwError; +} + + +static +DWORD +DecrementRefCount( + PWSTR pszSidString, + PDWORD pdwRefCount) +{ + HKEY hProfilesKey = NULL, hProfileKey = NULL; + DWORD dwRefCount = 0, dwLength, dwType; + DWORD dwError; + + DPRINT1("DecrementRefCount(%S %p)\n", + pszSidString, pdwRefCount); + + dwError = RegOpenKeyExW(HKEY_LOCAL_MACHINE, + L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList", + 0, + KEY_QUERY_VALUE, + &hProfilesKey); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + dwError = RegOpenKeyExW(hProfilesKey, + pszSidString, + 0, + KEY_QUERY_VALUE | KEY_SET_VALUE, + &hProfileKey); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + /* Get the reference counter */ + dwLength = sizeof(dwRefCount); + dwError = RegQueryValueExW(hProfileKey, + L"RefCount", + NULL, + &dwType, + (PBYTE)&dwRefCount, + &dwLength); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + dwRefCount--; + + dwLength = sizeof(dwRefCount); + dwError = RegSetValueExW(hProfileKey, + L"RefCount", + 0, + REG_DWORD, + (PBYTE)&dwRefCount, + dwLength); + if (dwError != ERROR_SUCCESS) + { + DPRINT1("Error: %lu\n", dwError); + goto done; + } + + if (pdwRefCount != NULL) + *pdwRefCount = dwRefCount; + +done: + if (hProfileKey != NULL) + RegCloseKey(hProfileKey); + + if (hProfilesKey != NULL) + RegCloseKey(hProfilesKey); + + return dwError; +} + + /* PUBLIC FUNCTIONS ********************************************************/ BOOL @@ -1613,6 +1764,14 @@ LoadUserProfileW( goto cleanup; } + Error = IncrementRefCount(SidString.Buffer, NULL); + if (Error != ERROR_SUCCESS) + { + DPRINT1("IncrementRefCount() failed (Error %ld)\n", Error); + SetLastError((DWORD)Error); + goto cleanup; + } + ret = TRUE; cleanup: @@ -1640,6 +1799,7 @@ UnloadUserProfile( { UNICODE_STRING SidString = {0, 0, NULL}; HANDLE hProfileMutex = NULL; + DWORD dwRefCount = 0; LONG Error; BOOL bRet = FALSE; @@ -1675,46 +1835,59 @@ UnloadUserProfile( /* Close the profile handle */ RegCloseKey(hProfile); - /* Acquire restore privilege */ - if (!AcquireRemoveRestorePrivilege(TRUE)) - { - DPRINT1("AcquireRemoveRestorePrivilege() failed (Error %ld)\n", GetLastError()); - goto cleanup; - } - - /* HACK */ - { - HKEY hUserKey; - - Error = RegOpenKeyExW(HKEY_USERS, - SidString.Buffer, - 0, - KEY_WRITE, - &hUserKey); - if (Error == ERROR_SUCCESS) - { - RegDeleteKeyW(hUserKey, - L"Volatile Environment"); - - RegCloseKey(hUserKey); - } - } - /* End of HACK */ - - /* Unload the hive */ - Error = RegUnLoadKeyW(HKEY_USERS, - SidString.Buffer); - - /* Remove restore privilege */ - AcquireRemoveRestorePrivilege(FALSE); - + Error = DecrementRefCount(SidString.Buffer, &dwRefCount); if (Error != ERROR_SUCCESS) { - DPRINT1("RegUnLoadKeyW() failed (Error %ld)\n", Error); + DPRINT1("DecrementRefCount() failed (Error %ld)\n", Error); SetLastError((DWORD)Error); goto cleanup; } + if (dwRefCount == 0) + { + DPRINT1("RefCount is 0: Unload the Hive!\n"); + + /* Acquire restore privilege */ + if (!AcquireRemoveRestorePrivilege(TRUE)) + { + DPRINT1("AcquireRemoveRestorePrivilege() failed (Error %ld)\n", GetLastError()); + goto cleanup; + } + + /* HACK */ + { + HKEY hUserKey; + + Error = RegOpenKeyExW(HKEY_USERS, + SidString.Buffer, + 0, + KEY_WRITE, + &hUserKey); + if (Error == ERROR_SUCCESS) + { + RegDeleteKeyW(hUserKey, + L"Volatile Environment"); + + RegCloseKey(hUserKey); + } + } + /* End of HACK */ + + /* Unload the hive */ + Error = RegUnLoadKeyW(HKEY_USERS, + SidString.Buffer); + + /* Remove restore privilege */ + AcquireRemoveRestorePrivilege(FALSE); + + if (Error != ERROR_SUCCESS) + { + DPRINT1("RegUnLoadKeyW() failed (Error %ld)\n", Error); + SetLastError((DWORD)Error); + goto cleanup; + } + } + bRet = TRUE; cleanup: