first init

This commit is contained in:
jiangcuo 2025-04-14 10:19:00 +08:00
commit 75fd783bc0
4 changed files with 1037 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

20
Cargo.toml Normal file
View File

@ -0,0 +1,20 @@
[package]
name = "pxvdi-hyperv-agent"
version = "0.1.0"
edition = "2021"
[dependencies]
actix-web = { version = "4.4.0", features = ["rustls"] }
tokio = { version = "1.32.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9"
uuid = { version = "1.4.1", features = ["v4"] }
anyhow = "1.0.75"
rustls = "0.20.9"
rustls-pemfile = "0.2.1"
rand = "0.8.5"
windows = { version = "0.48", features = ["Win32_Foundation", "Win32_Security", "Win32_System_Threading", "Win32_UI_WindowsAndMessaging", "Win32_System_ProcessStatus", "Win32_System_Diagnostics_Debug"] }
wmi = "0.15.2"
regex = "1.11.1"
chrono = "0.4.40"

87
README.md Normal file
View File

@ -0,0 +1,87 @@
# PXVDI Hyper-V Agent
这是一个用Rust编写的Hyper-V管理代理提供REST API接口来管理Hyper-V虚拟机。
## 功能特性
- 列出所有虚拟机
- 获取虚拟机详细信息
- 启动/停止虚拟机
- 获取虚拟机网络信息
- 获取虚拟机快照信息
## 系统要求
- Windows操作系统
- 已安装Hyper-V
- Rust开发环境
## 安装
1. 确保已安装Rust开发环境
2. 克隆此仓库
3. 在项目目录中运行:
```bash
cargo build --release
```
## 运行
```bash
cargo run --release
```
服务器将在 http://localhost:3000 启动
## API接口
所有API请求都需要在header中包含 `Key: hyperv`
### 获取虚拟机列表
```
GET /api2/ListVM
```
### 获取虚拟机详细信息
```
GET /api2/ListVMDetils?VMID=<虚拟机ID>
```
### 获取所有虚拟机信息
```
GET /api2/GetALL
```
### 获取网络信息
```
GET /api2/GetNetWork
GET /api2/GetNetWork?VMID=<虚拟机ID>
```
### 获取快照信息
```
GET /api2/GetSnapShot
GET /api2/GetSnapShot?VMID=<虚拟机ID>
```
### 停止虚拟机
```
POST /api2/StopVM?VMID=<虚拟机ID>
```
### 启动虚拟机
```
POST /api2/StartVM?VMID=<虚拟机ID>
```
## 错误处理
所有API都会返回适当的HTTP状态码和JSON格式的错误信息
- 400: 请求参数错误
- 401: 认证错误
- 500: 服务器内部错误
## 许可证
MIT

929
src/main.rs Normal file
View File

@ -0,0 +1,929 @@
use actix_web::{web, App, HttpResponse, HttpServer, Responder};
use serde::{Deserialize, Serialize};
use std::process::Command;
use uuid::Uuid;
use anyhow::Result;
use std::fs;
use std::path::Path;
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
use rand::{Rng, distributions::Alphanumeric};
use windows::{
Win32::Foundation::*,
Win32::Security::*,
Win32::System::Threading::*,
};
use std::env;
use std::os::windows::process::CommandExt;
use wmi::{COMLibrary, WMIConnection, Variant, FilterValue};
use std::collections::HashMap;
use regex;
fn is_admin() -> bool {
unsafe {
let mut token_handle = HANDLE::default();
if !OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &mut token_handle).as_bool() {
return false;
}
let mut elevation = TOKEN_ELEVATION::default();
let mut size = std::mem::size_of::<TOKEN_ELEVATION>() as u32;
let result = GetTokenInformation(
token_handle,
TokenElevation,
Some(&mut elevation as *mut _ as *mut _),
size,
&mut size,
);
CloseHandle(token_handle);
result.as_bool() && elevation.TokenIsElevated != 0
}
}
fn elevate() -> Result<()> {
let exe_path = env::current_exe()?;
let exe_path_str = exe_path.to_string_lossy();
// 使用Command启动进程并请求管理员权限
let mut cmd = Command::new("powershell");
cmd.arg("-Command")
.arg(format!("Start-Process -FilePath '{}' -Verb RunAs", exe_path_str))
.creation_flags(0x08000000); // CREATE_NO_WINDOW
let status = cmd.status()?;
if !status.success() {
anyhow::bail!("提权失败,退出码: {:?}", status.code());
}
std::process::exit(0);
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Config {
key: String,
port: u16,
host: String,
tls: bool,
cert_path: Option<String>,
key_path: Option<String>,
}
impl Default for Config {
fn default() -> Self {
Config {
key: "Dsda12hg843".to_string(),
port: 9654,
host: "127.0.0.1".to_string(),
tls: false,
cert_path: None,
key_path: None,
}
}
}
fn load_config() -> Result<Config> {
let config_path = Path::new("config.yml");
if !config_path.exists() {
// 如果配置文件不存在,创建默认配置
let default_config = Config::default();
let yaml = serde_yaml::to_string(&default_config)?;
fs::write(config_path, yaml)?;
return Ok(default_config);
}
// 读取现有配置
let contents = fs::read_to_string(config_path)?;
let config: Config = serde_yaml::from_str(&contents)?;
Ok(config)
}
fn load_rustls_config(cert_path: &str, key_path: &str) -> Result<ServerConfig> {
// 加载证书
let cert_file = File::open(cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let cert_chain = certs(&mut cert_reader)?
.into_iter()
.map(Certificate)
.collect();
// 加载私钥
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
let mut keys = pkcs8_private_keys(&mut key_reader)?;
if keys.is_empty() {
anyhow::bail!("No private keys found");
}
let key = PrivateKey(keys.remove(0));
// 创建TLS配置
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert_chain, key)?;
Ok(config)
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct VirtualMachine {
vmid: String,
name: String,
state: String,
status: String,
cpu_cores: u32,
memory_mb: u64,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct IpInfo {
mac_address: String,
switch_name: String,
ip_addresses: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct NetworkAdapter {
macaddress: String,
vmid: String,
switchname: String,
ipaddresses: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Snapshot {
name: String,
creation_time: String,
}
// 获取WMI连接
fn get_wmi_connection() -> Result<WMIConnection> {
let com_lib = COMLibrary::new()?;
let wmi_con = WMIConnection::with_namespace_path("root\\virtualization\\v2", com_lib)?;
Ok(wmi_con)
}
// 将Variant转换为String的辅助函数
fn variant_to_string(v: &Variant) -> Option<String> {
match v {
Variant::String(s) => Some(s.clone()),
Variant::I2(i) => Some(i.to_string()),
Variant::UI2(i) => Some(i.to_string()),
Variant::I4(i) => Some(i.to_string()),
Variant::UI4(i) => Some(i.to_string()),
Variant::R4(f) => Some(f.to_string()),
Variant::R8(f) => Some(f.to_string()),
Variant::Bool(b) => Some(b.to_string()),
_ => None,
}
}
async fn stop_vm(vmid: &str) -> Result<String> {
let wmi_con = get_wmi_connection()?;
// 查询虚拟机
let mut filters = HashMap::new();
filters.insert("Name".to_string(), FilterValue::String(vmid.to_string()));
let vms: Vec<HashMap<String, Variant>> = wmi_con.filtered_query(&filters)?;
if vms.is_empty() {
return Ok("VM not found".to_string());
}
// 使用PowerShell停止虚拟机因为WMI库中exec_method实现有问题
let output = Command::new("powershell.exe")
.args(["-Command", &format!("$OutputEncoding = [System.Text.Encoding]::UTF8; Get-CimInstance -Namespace root\\virtualization\\v2 -ClassName Msvm_ComputerSystem -Filter \"Name = '{}'\" | Invoke-CimMethod -MethodName RequestStateChange -Arguments @{{RequestedState=3; Force=$true}}", vmid)])
.output()?;
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
async fn start_vm(vmid: &str) -> Result<String> {
let wmi_con = get_wmi_connection()?;
// 查询虚拟机
let mut filters = HashMap::new();
filters.insert("Name".to_string(), FilterValue::String(vmid.to_string()));
let vms: Vec<HashMap<String, Variant>> = wmi_con.filtered_query(&filters)?;
if vms.is_empty() {
return Ok("VM not found".to_string());
}
// 使用PowerShell启动虚拟机因为WMI库中exec_method实现有问题
let output = Command::new("powershell.exe")
.args(["-Command", &format!("$OutputEncoding = [System.Text.Encoding]::UTF8; Get-CimInstance -Namespace root\\virtualization\\v2 -ClassName Msvm_ComputerSystem -Filter \"Name = '{}'\" | Invoke-CimMethod -MethodName RequestStateChange -Arguments @{{RequestedState=2}}", vmid)])
.output()?;
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
async fn list_vm_details(vmid: &str) -> Result<String> {
let wmi_con = get_wmi_connection()?;
// 查询特定虚拟机
let mut filters = HashMap::new();
filters.insert("Name".to_string(), FilterValue::String(vmid.to_string()));
let vms: Vec<HashMap<String, Variant>> = wmi_con.filtered_query(&filters)?;
if vms.is_empty() {
return Ok("{}".to_string());
}
// 转换结果为JSON
let json = serde_json::to_string(&vms[0])?;
Ok(json)
}
async fn list_vm() -> Result<String> {
let wmi_con = get_wmi_connection()?;
// 使用WQL查询虚拟机适应中英文环境
let raw_query = "SELECT * FROM Msvm_ComputerSystem WHERE Caption='虚拟机' OR Caption='Virtual Machine'";
let vms_result: Vec<HashMap<String, Variant>> = wmi_con.raw_query(raw_query)?;
let mut vms = Vec::new();
for vm in vms_result {
let vm_id = vm.get("Name").and_then(variant_to_string).unwrap_or_default();
if vm_id.is_empty() {
continue;
}
// 获取虚拟机状态
let state = match vm.get("EnabledState") {
Some(Variant::I2(2)) => "Running".to_string(),
Some(Variant::I2(3)) => "Stopped".to_string(),
Some(Variant::I2(9)) => "Paused".to_string(),
Some(Variant::I2(state)) => format!("State_{}", state),
Some(Variant::I4(2)) => "Running".to_string(),
Some(Variant::I4(3)) => "Stopped".to_string(),
Some(Variant::I4(9)) => "Paused".to_string(),
Some(Variant::I4(state)) => format!("State_{}", state),
Some(Variant::UI2(2)) => "Running".to_string(),
Some(Variant::UI2(3)) => "Stopped".to_string(),
Some(Variant::UI2(9)) => "Paused".to_string(),
Some(Variant::UI2(state)) => format!("State_{}", *state),
Some(other) => format!("Unknown: {:?}", other),
None => "Unknown".to_string(),
};
// 获取CPU核心数
let cpu_cores = get_vm_cpu_cores(&wmi_con, &vm_id)?;
// 获取内存大小
let memory_mb = get_vm_memory(&wmi_con, &vm_id)?;
// 创建VM对象
vms.push(VirtualMachine {
vmid: vm_id,
name: vm.get("ElementName").and_then(variant_to_string).unwrap_or_default(),
state,
status: vm.get("Status").and_then(variant_to_string).unwrap_or_default(),
cpu_cores,
memory_mb,
});
}
// 转换为JSON
let json = serde_json::to_string(&vms)?;
Ok(json)
}
// 获取虚拟机网络信息
fn get_vm_network_info(wmi_con: &WMIConnection, vm_id: &str) -> Result<IpInfo> {
println!("获取VM {} 的网络信息", vm_id);
// 查询网络适配器
let net_query = format!("SELECT * FROM Msvm_SyntheticEthernetPortSettingData WHERE InstanceID LIKE '%{}%'", vm_id);
let network_adapters: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&net_query)?;
if network_adapters.is_empty() {
println!("没有找到网络适配器");
return Ok(IpInfo {
mac_address: String::new(),
switch_name: String::new(),
ip_addresses: Vec::new(),
});
}
// 获取第一个网络适配器信息
let adapter = &network_adapters[0];
let mac_address = adapter.get("Address").and_then(variant_to_string).unwrap_or_default();
let switch_name = adapter.get("SwitchName").and_then(variant_to_string).unwrap_or_default();
// 尝试获取IP地址
let mut ip_addresses = Vec::new();
// 查询关联的IP配置
if !mac_address.is_empty() {
// 通过虚拟机的管理服务获取IP地址
if is_uuid(vm_id) {
// 首先获取GuestIntrinsicExchangeItems可能包含IP地址
let guestinfo_query = format!(
"ASSOCIATORS OF {{Msvm_ComputerSystem.CreationClassName='Msvm_ComputerSystem',Name='{}'}} \
WHERE AssocClass=Msvm_SystemDevice ResultClass=Msvm_KvpExchangeComponent",
vm_id
);
let guest_info: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&guestinfo_query)?;
if !guest_info.is_empty() {
// 解析GuestExchangeItems以获取IP地址
if let Some(exchange_items) = guest_info[0].get("GuestExchangeItems").and_then(variant_to_string) {
for line in exchange_items.lines() {
if line.contains("NetworkAddressIPv4") || line.contains("NetworkAddressIPv6") {
if let Some(ip) = line.split_whitespace().last() {
if !ip.is_empty() && !ip_addresses.contains(&ip.to_string()) {
ip_addresses.push(ip.to_string());
}
}
}
}
}
}
}
// 如果上面方法没有找到IP尝试其他方式
if ip_addresses.is_empty() {
// 尝试查询连接到具体网络适配器的IP地址信息
let ip_query = format!(
"SELECT * FROM Msvm_GuestNetworkAdapterConfiguration \
WHERE InstanceID LIKE '%{}%'",
mac_address.replace(":", "")
);
let ip_configs: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&ip_query)?;
if !ip_configs.is_empty() {
// 尝试从IPAddresses获取IP
if let Some(Variant::Array(ip_array)) = ip_configs[0].get("IPAddresses") {
for ip_var in ip_array {
if let Some(ip) = variant_to_string(ip_var) {
if !ip.is_empty() && !ip_addresses.contains(&ip) {
ip_addresses.push(ip);
}
}
}
}
}
}
}
println!("网络信息: MAC={}, Switch={}, IPs={:?}", mac_address, switch_name, ip_addresses);
Ok(IpInfo {
mac_address,
switch_name,
ip_addresses,
})
}
// 获取虚拟机CPU核心数
fn get_vm_cpu_cores(wmi_con: &WMIConnection, vm_id: &str) -> Result<u32> {
// 方法1: 通过处理器设置关联查询
let query1 = format!(
"SELECT * FROM Msvm_ProcessorSettingData WHERE InstanceID LIKE '%{}%'",
vm_id
);
let processor_settings1: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&query1)?;
if !processor_settings1.is_empty() {
// 从VirtualQuantity获取核心数
if let Some(cores) = processor_settings1[0].get("VirtualQuantity") {
match cores {
Variant::I4(cores) => return Ok(*cores as u32),
Variant::UI4(cores) => return Ok(*cores),
Variant::I2(cores) => return Ok(*cores as u32),
Variant::UI2(cores) => return Ok(*cores as u32),
Variant::I8(cores) => return Ok(*cores as u32),
Variant::UI8(cores) => return Ok(*cores as u32),
_ => println!("获取CPU核心数失败: {:?}", cores),
}
}
}
// 默认返回1个核心
println!("无法确定CPU核心数使用默认值: 1");
Ok(1)
}
// 获取虚拟机内存大小(MB)
fn get_vm_memory(wmi_con: &WMIConnection, vm_id: &str) -> Result<u64> {
// 方法1: 通过内存设置关联查询
let query1 = format!(
"ASSOCIATORS OF {{Msvm_ComputerSystem.CreationClassName='Msvm_ComputerSystem',Name='{}'}} \
WHERE AssocClass=Msvm_SettingsDefineState ResultClass=Msvm_MemorySettingData",
vm_id
);
let memory_settings1: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&query1)?;
if !memory_settings1.is_empty() {
// 从VirtualQuantity获取内存大小(MB)
if let Some(mem) = memory_settings1[0].get("VirtualQuantity") {
match mem {
Variant::I4(mem) => return Ok(*mem as u64),
Variant::UI4(mem) => return Ok(*mem as u64),
Variant::I8(mem) => return Ok(*mem as u64),
Variant::UI8(mem) => return Ok(*mem),
_ => println!("方法1: 内存大小类型不匹配: {:?}", mem),
}
}
}
// 方法2: 尝试直接查询内存设置
let query2 = format!(
"SELECT * FROM Msvm_MemorySettingData WHERE InstanceID LIKE '%{}%'",
vm_id
);
let memory_settings2: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&query2)?;
if !memory_settings2.is_empty() {
if let Some(mem) = memory_settings2[0].get("VirtualQuantity") {
match mem {
Variant::I4(mem) => return Ok(*mem as u64),
Variant::UI4(mem) => return Ok(*mem as u64),
Variant::I8(mem) => return Ok(*mem as u64),
Variant::UI8(mem) => return Ok(*mem),
_ => println!("方法2: 内存大小类型不匹配: {:?}", mem),
}
}
}
// 方法3: 查询虚拟机配置
let query3 = format!(
"SELECT * FROM Msvm_VirtualSystemSettingData WHERE InstanceID LIKE '%{}%'",
vm_id
);
let vm_settings: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&query3)?;
if !vm_settings.is_empty() {
if let Some(mem) = vm_settings[0].get("MemoryStartup") {
match mem {
Variant::I4(mem) => return Ok(*mem as u64),
Variant::UI4(mem) => return Ok(*mem as u64),
Variant::I8(mem) => return Ok(*mem as u64),
Variant::UI8(mem) => return Ok(*mem),
_ => println!("方法3: 内存大小类型不匹配: {:?}", mem),
}
}
}
// 默认返回1GB
println!("无法确定内存大小,使用默认值: 1024MB");
Ok(1024)
}
// 获取虚拟机快照列表
fn get_vm_snapshots(wmi_con: &WMIConnection, vm_id: &str) -> Result<Vec<Snapshot>> {
println!("获取VM {} 的快照列表", vm_id);
// 查询快照列表
let query = format!(
"SELECT * FROM Msvm_VirtualSystemSettingData \
WHERE VirtualSystemType='Microsoft:Hyper-V:Snapshot:Realized' AND InstanceID LIKE '%{}%'",
vm_id
);
let snapshot_results: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&query)?;
println!("找到 {} 个快照", snapshot_results.len());
let mut snapshots = Vec::new();
for snapshot in snapshot_results {
let name = snapshot.get("ElementName").and_then(variant_to_string).unwrap_or_default();
let creation_time = snapshot.get("CreationTime").and_then(variant_to_string).unwrap_or_default();
if !name.is_empty() {
snapshots.push(Snapshot {
name,
creation_time,
});
}
}
println!("快照列表: {:?}", snapshots.iter().map(|s| &s.name).collect::<Vec<_>>());
Ok(snapshots)
}
async fn get_all() -> Result<String> {
let wmi_con = get_wmi_connection()?;
// 使用WQL查询虚拟机适应中英文环境
let raw_query = "SELECT * FROM Msvm_ComputerSystem WHERE Caption='虚拟机' OR Caption='Virtual Machine'";
let vms: Vec<HashMap<String, Variant>> = wmi_con.raw_query(raw_query)?;
// 转换为JSON
let json = serde_json::to_string(&vms)?;
Ok(json)
}
async fn get_network(vmid: Option<&str>) -> Result<String> {
let wmi_con = get_wmi_connection()?;
let mut adapters = Vec::new();
// 根据是否有VMID选择查询方式
if let Some(id) = vmid {
// 先获取特定虚拟机
let mut vm_filters = HashMap::new();
vm_filters.insert("Name".to_string(), FilterValue::String(id.to_string()));
let vms: Vec<HashMap<String, Variant>> = wmi_con.filtered_query(&vm_filters)?;
if !vms.is_empty() {
// 查询该虚拟机的网络适配器
let wql = format!("SELECT * FROM Msvm_SyntheticEthernetPortSettingData WHERE InstanceID LIKE '%{}%'", id);
let network_adapters: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&wql)?;
for adapter in network_adapters {
adapters.push(NetworkAdapter {
macaddress: adapter.get("Address").and_then(variant_to_string).unwrap_or_default(),
vmid: id.to_string(),
switchname: adapter.get("SwitchName").and_then(variant_to_string).unwrap_or_default(),
ipaddresses: Vec::new(), // WMI中不直接提供IP地址需要额外查询
});
}
}
} else {
// 查询所有网络适配器
let adapters_result: Vec<HashMap<String, Variant>> = wmi_con.query()?;
for adapter in adapters_result {
// 从实例ID中提取VMID
let instance_id = adapter.get("InstanceID").and_then(variant_to_string).unwrap_or_default();
let vmid = extract_uuid_from_instance_id(&instance_id);
adapters.push(NetworkAdapter {
macaddress: adapter.get("Address").and_then(variant_to_string).unwrap_or_default(),
vmid,
switchname: adapter.get("SwitchName").and_then(variant_to_string).unwrap_or_default(),
ipaddresses: Vec::new(),
});
}
}
// 转换为JSON
let json = serde_json::to_string(&adapters)?;
Ok(json)
}
// 从实例ID中提取UUID
fn extract_uuid_from_instance_id(instance_id: &str) -> String {
let re_str = r"\{[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\}";
let re = regex::Regex::new(re_str).unwrap();
if let Some(captures) = re.captures(instance_id) {
if let Some(m) = captures.get(0) {
return m.as_str().to_string();
}
}
"".to_string()
}
async fn get_snapshot(vmid: Option<&str>) -> Result<String> {
let wmi_con = get_wmi_connection()?;
let mut snapshots = Vec::new();
let wql = if let Some(id) = vmid {
format!("SELECT * FROM Msvm_VirtualSystemSettingData WHERE VirtualSystemType='Microsoft:Hyper-V:Snapshot:Realized' AND InstanceID LIKE '%{}%'", id)
} else {
"SELECT * FROM Msvm_VirtualSystemSettingData WHERE VirtualSystemType='Microsoft:Hyper-V:Snapshot:Realized'".to_string()
};
let snapshot_results: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&wql)?;
for snapshot in snapshot_results {
// 提取快照信息
snapshots.push(Snapshot {
name: snapshot.get("ElementName").and_then(variant_to_string).unwrap_or_default(),
creation_time: snapshot.get("CreationTime").and_then(variant_to_string).unwrap_or_default(),
});
}
// 转换为JSON
let json = serde_json::to_string(&snapshots)?;
Ok(json)
}
async fn check_vmid(vmid: &str) -> Result<bool> {
if !is_uuid(vmid) {
return Ok(false);
}
let wmi_con = get_wmi_connection()?;
// 使用原始WQL查询直接按名称查找
let raw_query = format!("SELECT * FROM Msvm_ComputerSystem WHERE Name='{}'", vmid);
let vms: Vec<HashMap<String, Variant>> = wmi_con.raw_query(&raw_query)?;
if !vms.is_empty() {
return Ok(true);
}
// 如果WMI查询未找到尝试通过PowerShell检查
println!("WMI未找到VM尝试PowerShell: {}", vmid);
let output = Command::new("powershell.exe")
.args(["-Command", &format!("$OutputEncoding = [System.Text.Encoding]::UTF8; Get-VM -Id '{}' -ErrorAction SilentlyContinue", vmid)])
.output()?;
let success = output.status.success() && !String::from_utf8_lossy(&output.stdout).trim().is_empty();
println!("PowerShell结果: {}", success);
Ok(success)
}
async fn check_os() -> Result<()> {
let wmi_result = get_wmi_connection();
if wmi_result.is_err() {
anyhow::bail!("Hyper-V WMI namespace not available. Hyper-V may not be installed or running");
}
Ok(())
}
async fn list_vm_details_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID");
match vmid {
Some(id) => {
if let Ok(true) = check_vmid(id).await {
match list_vm_details(id).await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to get VM details" })),
}
} else {
HttpResponse::BadRequest().json(serde_json::json!({ "error": "Invalid VMID" }))
}
},
None => HttpResponse::BadRequest().json(serde_json::json!({ "error": "Need VMID" })),
}
}
async fn get_all_handler(req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
match get_all().await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to get VM details" })),
}
}
async fn get_network_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID").map(|s| s.as_str());
match get_network(vmid).await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to get network details" })),
}
}
async fn get_snapshot_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID").map(|s| s.as_str());
match get_snapshot(vmid).await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to get snapshot details" })),
}
}
async fn stop_vm_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID");
match vmid {
Some(id) => {
if let Ok(true) = check_vmid(id).await {
match stop_vm(id).await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to stop VM" })),
}
} else {
HttpResponse::BadRequest().json(serde_json::json!({ "error": "Invalid VMID" }))
}
},
None => HttpResponse::BadRequest().json(serde_json::json!({ "error": "Need VMID" })),
}
}
async fn start_vm_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID");
match vmid {
Some(id) => {
if let Ok(true) = check_vmid(id).await {
match start_vm(id).await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to start VM" })),
}
} else {
HttpResponse::BadRequest().json(serde_json::json!({ "error": "Invalid VMID" }))
}
},
None => HttpResponse::BadRequest().json(serde_json::json!({ "error": "Need VMID" })),
}
}
async fn list_vm_handler(req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
match list_vm().await {
Ok(result) => HttpResponse::Ok().body(result),
Err(_) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": "Failed to list VMs" })),
}
}
async fn get_vm_ticket_handler(query: web::Query<std::collections::HashMap<String, String>>, req: actix_web::HttpRequest, config: web::Data<Config>) -> impl Responder {
let key = req.headers().get("Key").and_then(|v| v.to_str().ok());
if key != Some(&config.key) {
return HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Key error" }));
}
let vmid = query.get("VMID");
match vmid {
Some(id) => {
if let Ok(true) = check_vmid(id).await {
match generate_vm_ticket(id).await {
Ok((username, password)) => {
HttpResponse::Ok().json(serde_json::json!({
"username": username,
"password": password
}))
},
Err(e) => HttpResponse::InternalServerError().json(serde_json::json!({
"error": format!("Failed to generate VM ticket: {}", e)
})),
}
} else {
HttpResponse::BadRequest().json(serde_json::json!({ "error": "Invalid VMID" }))
}
},
None => HttpResponse::BadRequest().json(serde_json::json!({ "error": "Need VMID" })),
}
}
fn is_uuid(s: &str) -> bool {
Uuid::parse_str(s).is_ok()
}
async fn generate_vm_ticket(_vmid: &str) -> Result<(String, String)> {
// 生成16位随机用户名
let username: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
// 生成16位满足Windows密码复杂度的密码
let mut rng = rand::thread_rng();
let special_chars = "!@#$%^&*()_-+=<>?";
let uppercase_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let lowercase_letters = "abcdefghijklmnopqrstuvwxyz";
let numbers = "0123456789";
let mut password = String::with_capacity(16);
// 确保密码包含至少一个特殊字符、一个大写字母、一个小写字母和一个数字
password.push(special_chars.chars().nth(rng.gen_range(0..special_chars.len())).unwrap());
password.push(uppercase_letters.chars().nth(rng.gen_range(0..uppercase_letters.len())).unwrap());
password.push(lowercase_letters.chars().nth(rng.gen_range(0..lowercase_letters.len())).unwrap());
password.push(numbers.chars().nth(rng.gen_range(0..numbers.len())).unwrap());
// 填充剩余字符
let all_chars = format!("{}{}{}{}", special_chars, uppercase_letters, lowercase_letters, numbers);
for _ in 0..12 {
password.push(all_chars.chars().nth(rng.gen_range(0..all_chars.len())).unwrap());
}
// 打乱密码字符顺序
let mut password_chars: Vec<char> = password.chars().collect();
for i in (1..password_chars.len()).rev() {
let j = rng.gen_range(0..=i);
password_chars.swap(i, j);
}
let password: String = password_chars.into_iter().collect();
// 创建临时用户并添加到Hyper-V管理员组一分钟后自动删除
let create_user_cmd = format!(
"New-LocalUser -Name '{}' -Password (ConvertTo-SecureString '{}' -AsPlainText -Force) -Description 'Temporary Hyper-V access'; Add-LocalGroupMember -Group 'Hyper-V Administrators' -Member '{}'; Start-Job -ScriptBlock {{ Start-Sleep -Seconds 60; Remove-LocalUser -Name '{}' }}",
username, password, username, username
);
let output = Command::new("powershell.exe")
.args(["-Command", &create_user_cmd])
.output()?;
if !output.status.success() {
anyhow::bail!("Failed to create temporary user: {}", String::from_utf8_lossy(&output.stderr));
}
Ok((username, password))
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
// 检查管理员权限
if !is_admin() {
println!("需要管理员权限才能运行此程序,正在尝试提权...");
if let Err(e) = elevate() {
eprintln!("提权失败: {}", e);
return Err(std::io::Error::new(std::io::ErrorKind::PermissionDenied, "需要管理员权限"));
}
}
// 检查Hyper-V是否可用
if let Err(e) = check_os().await {
eprintln!("错误: {}", e);
return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string()));
}
// 加载配置
let config = match load_config() {
Ok(config) => config,
Err(e) => {
eprintln!("Error loading config: {}", e);
std::process::exit(1);
}
};
let port = config.port;
let host = config.host.clone();
let tls_enabled = config.tls;
let cert_path = config.cert_path.clone();
let key_path = config.key_path.clone();
let server = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(config.clone()))
.route("/api2/ListVMDetils", web::get().to(list_vm_details_handler))
.route("/api2/GetALL", web::get().to(get_all_handler))
.route("/api2/GetNetWork", web::get().to(get_network_handler))
.route("/api2/GetSnapShot", web::get().to(get_snapshot_handler))
.route("/api2/StopVM", web::post().to(stop_vm_handler))
.route("/api2/StartVM", web::post().to(start_vm_handler))
.route("/api2/ListVM", web::get().to(list_vm_handler))
.route("/api2/getvmticket", web::get().to(get_vm_ticket_handler))
});
let server = if tls_enabled {
match (&cert_path, &key_path) {
(Some(cert_path), Some(key_path)) => {
match load_rustls_config(cert_path, key_path) {
Ok(rustls_config) => {
println!("API service running at https://{}:{}", host, port);
server.bind_rustls((host, port), rustls_config)?
},
Err(e) => {
eprintln!("Error loading TLS configuration: {}", e);
std::process::exit(1);
}
}
},
_ => {
eprintln!("TLS is enabled but cert_path or key_path is missing in config");
std::process::exit(1);
}
}
} else {
println!("API service running at http://{}:{}", host, port);
server.bind((host, port))?
};
server.run().await
}