Administrator
发布于 2025-11-27 / 3 阅读
0
0

使用 ResNet50 训练花卉识别 AI

import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import re
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: prefer a CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

class FlowerDataset(Dataset):
    """Flower image dataset parsed from a CSV of (filename, category_id, chinese_name, english_name).

    Parsing strategy: try pandas first (with or without a header row); fall back
    to a manual line-by-line parser when pandas fails or yields no usable rows.
    Only rows whose filename ends in ``.jpg`` are kept.
    """

    def __init__(self, csv_file, transform=None):
        # transform: optional torchvision transform applied to each loaded PIL image.
        self.transform = transform
        self.data = []             # list of dicts: filename / category_id / chinese_name / english_name
        self.class_to_idx = {}     # raw category_id -> contiguous label index
        self.idx_to_class = {}     # contiguous label index -> raw category_id
        self.category_names = {}   # raw category_id -> {'chinese': ..., 'english': ...}
        self._parse_csv(csv_file)

    def _parse_csv(self, csv_file):
        """Parse the CSV with pandas; fall back to manual parsing on failure."""
        try:
            df = pd.read_csv(csv_file)

            required_columns = ['filename', 'category_id', 'chinese_name', 'english_name']
            if not all(col in df.columns for col in required_columns):
                # No usable header row: re-read headerless and assign names positionally.
                df = pd.read_csv(csv_file, header=None)
                if len(df.columns) >= 4:
                    df.columns = ['filename', 'category_id', 'chinese_name', 'english_name'][:len(df.columns)]
                else:
                    raise ValueError("CSV文件列数不足")

        except Exception as e:
            print(f"使用pandas读取失败,尝试手动解析: {e}")
            self._parse_csv_manual(csv_file)
            return

        for _, row in df.iterrows():
            filename = str(row['filename']).strip()
            if not filename.lower().endswith('.jpg'):
                continue  # only .jpg entries are considered valid samples

            try:
                self.data.append({
                    'filename': filename,
                    'category_id': int(row['category_id']),
                    'chinese_name': str(row['chinese_name']).strip(),
                    'english_name': str(row['english_name']).strip()
                })
            except (ValueError, KeyError):
                continue  # skip malformed rows rather than aborting the whole load

        if len(self.data) == 0:
            # pandas parsed the file but produced no usable rows; retry manually.
            self._parse_csv_manual(csv_file)
            return

        self._build_category_mappings()

    def _parse_csv_manual(self, csv_file):
        """Line-by-line CSV fallback for files pandas cannot read."""
        with open(csv_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        for line in lines:
            line = line.strip()
            # Skip blank lines and the header row.
            if not line or line.startswith('filename'):
                continue

            parts = line.split(',')
            if len(parts) < 4:
                continue

            filename = parts[0].strip()
            if not filename.lower().endswith('.jpg'):
                continue

            try:
                self.data.append({
                    'filename': filename,
                    'category_id': int(parts[1].strip()),
                    'chinese_name': parts[2].strip(),
                    'english_name': parts[3].strip()
                })
            except (ValueError, IndexError):
                continue

        self._build_category_mappings()

    def _build_category_mappings(self):
        """Build contiguous label indices and per-category display names."""
        if len(self.data) == 0:
            print("警告: 没有找到有效数据")
            return

        # Sort so label indices are deterministic across runs.
        unique_categories = sorted({item['category_id'] for item in self.data})
        self.class_to_idx = {cls: idx for idx, cls in enumerate(unique_categories)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

        # Remember the first-seen Chinese/English names for each raw category id.
        self.category_names = {}
        for item in self.data:
            cat_id = item['category_id']
            if cat_id not in self.category_names:
                self.category_names[cat_id] = {
                    'chinese': item['chinese_name'],
                    'english': item['english_name']
                }

        print(f"成功加载 {len(self.data)} 张图片,{len(unique_categories)} 个类别")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (image, label_index, filename) for sample ``idx``.

        The image path is resolved by trying the CSV path as-is, then the bare
        basename in a few common directories. Raises FileNotFoundError when the
        file cannot be located or opened.
        """
        item = self.data[idx]
        filename = item['filename']

        if os.path.exists(filename):
            img_path = filename
        else:
            # Fall back to the bare filename searched in a few common directories.
            basename = os.path.basename(filename)
            if os.path.exists(basename):
                img_path = basename
            else:
                possible_paths = [
                    f"./images/{basename}",
                    f"./train/{basename}",
                    f"./data/{basename}",
                    basename
                ]
                for path in possible_paths:
                    if os.path.exists(path):
                        img_path = path
                        break
                else:
                    # Fixed: report which file is missing (the message previously lost the name).
                    raise FileNotFoundError(f"图片文件不存在: {filename}")

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            raise FileNotFoundError(f"无法打开图片 {img_path}: {e}")

        if self.transform:
            image = self.transform(image)

        label = self.class_to_idx[item['category_id']]
        return image, label, filename

    def get_category_info(self, category_id):
        """Return {'chinese', 'english'} names for a raw category id, with an unknown fallback."""
        return self.category_names.get(category_id, {'chinese': '未知', 'english': 'Unknown'})

class ModelCheckpoint:
    """Tracks the best validation score, persists checkpoints, and flags early stopping."""

    def __init__(self, filepath, patience=5, verbose=True, delta=0.001):
        self.filepath = filepath          # where the best checkpoint is written
        self.patience = patience          # stagnant epochs tolerated before stopping
        self.verbose = verbose
        self.delta = delta                # minimum score gain that counts as improvement
        self.best_score = None
        self.epochs_no_improve = 0
        self.early_stop = False

    def __call__(self, score, model, optimizer, scheduler, epoch, dataset_info):
        if self.best_score is None:
            # First evaluation: record it and save unconditionally.
            self._save_checkpoint(score, model, optimizer, scheduler, epoch, dataset_info)
            self.best_score = score
            return

        if score < self.best_score + self.delta:
            # No significant improvement: advance the early-stop counter.
            self.epochs_no_improve += 1
            if self.verbose:
                print(f'准确率未提升,耐心计数: {self.epochs_no_improve}/{self.patience}')
            if self.epochs_no_improve >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f'早停触发! 最佳准确率: {self.best_score:.4f}%')
            return

        # Improved: persist the new best model and reset the counter.
        self._save_checkpoint(score, model, optimizer, scheduler, epoch, dataset_info)
        self.best_score = score
        self.epochs_no_improve = 0
        if self.verbose:
            print(f'准确率提升: {self.best_score:.4f}%, 保存新模型')

    def _save_checkpoint(self, score, model, optimizer, scheduler, epoch, dataset_info):
        """Serialize model/optimizer state plus label-mapping metadata to ``filepath``."""
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'best_score': score,
            'class_to_idx': dataset_info['class_to_idx'],
            'idx_to_class': dataset_info['idx_to_class'],
            'category_names': dataset_info['category_names'],
            'num_classes': dataset_info['num_classes'],
        }
        torch.save(state, self.filepath)

class LearningRateScheduler:
    """Reduce-on-plateau learning-rate controller keyed to a monitored score."""

    def __init__(self, optimizer, initial_lr=0.001, patience=3, factor=0.5, min_lr=1e-6):
        self.optimizer = optimizer
        self.lr = initial_lr        # current learning rate applied to all param groups
        self.patience = patience    # stagnant epochs tolerated before a reduction
        self.factor = factor        # multiplicative decay applied on plateau
        self.min_lr = min_lr        # floor below which the rate is never reduced
        self.best_score = None
        self.epochs_no_improve = 0

    def step(self, score):
        """Record the latest score; decay the rate after ``patience`` stagnant epochs."""
        if self.best_score is None:
            self.best_score = score
            return

        if score >= self.best_score + 0.001:
            # Meaningful improvement: remember it and clear the stagnation counter.
            self.best_score = score
            self.epochs_no_improve = 0
            return

        # Gains below 0.001 are treated as stagnation.
        self.epochs_no_improve += 1
        if self.epochs_no_improve >= self.patience:
            self._reduce_lr()
            self.epochs_no_improve = 0

    def _reduce_lr(self):
        """Multiply the rate by ``factor`` (clamped at ``min_lr``) and push it to the optimizer."""
        old_lr = self.lr
        self.lr = max(self.lr * self.factor, self.min_lr)
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr
        print(f'学习率从 {old_lr:.6f} 降低到 {self.lr:.6f}')

class FlowerClassifier:
    """ResNet-50 wrapper bundling the model, loss, optimizer, and LR scheduling."""

    def __init__(self, num_classes, learning_rate=0.001, checkpoint_path=None):
        backbone = models.resnet50(pretrained=True)
        # Replace the ImageNet classification head with one sized for our classes.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone.to(device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.scheduler = LearningRateScheduler(self.optimizer, initial_lr=learning_rate)

        # Resume from a previous checkpoint when one is available on disk.
        if checkpoint_path and os.path.exists(checkpoint_path):
            self.load_checkpoint(checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer state from disk; fall back to fresh training on failure."""
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            saved_sched = checkpoint.get('scheduler_state_dict')
            if saved_sched:
                # NOTE(review): only a standard torch scheduler could consume this state;
                # `standard_scheduler` is never created here, so this branch is inert.
                if hasattr(self, 'standard_scheduler'):
                    self.standard_scheduler.load_state_dict(saved_sched)

            print(f"成功加载检查点: {checkpoint_path}")
            print(f"检查点信息: 最佳准确率 = {checkpoint.get('best_score', '未知')}")

        except Exception as e:
            print(f"加载检查点失败: {e}")
            print("将从头开始训练")

    def train_epoch(self, dataloader):
        """Run one training pass; return (mean loss, accuracy %) over the loader."""
        self.model.train()
        loss_sum = 0.0
        seen = 0
        hits = 0

        progress = tqdm(dataloader, desc="训练中")
        for step, (inputs, labels, _) in enumerate(progress):
            inputs = inputs.to(device)
            labels = labels.to(device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            loss_sum += loss.item()
            predictions = outputs.max(1)[1]
            seen += labels.size(0)
            hits += predictions.eq(labels).sum().item()

            progress.set_postfix({
                '损失': f'{loss_sum/(step+1):.4f}',
                '准确率': f'{100.*hits/seen:.4f}%'
            })

        return loss_sum / len(dataloader), 100. * hits / seen

    def evaluate(self, dataloader):
        """Evaluate on the loader; return (accuracy %, list of misclassified samples)."""
        self.model.eval()
        hits = 0
        seen = 0
        mistakes = []

        with torch.no_grad():
            progress = tqdm(dataloader, desc="评估中")
            for inputs, labels, filenames in progress:
                inputs = inputs.to(device)
                labels = labels.to(device)
                predictions = self.model(inputs).max(1)[1]

                seen += labels.size(0)
                hits += predictions.eq(labels).sum().item()

                # Keep every misprediction for later error analysis.
                for i in range(labels.size(0)):
                    if predictions[i] != labels[i]:
                        mistakes.append({
                            'filename': filenames[i],
                            'true_label': labels[i].item(),
                            'predicted_label': predictions[i].item()
                        })

                progress.set_postfix({
                    '准确率': f'{100.*hits/seen:.4f}%'
                })

        return 100. * hits / seen, mistakes

class TrainingManager:
    """Drives the epoch loop: train, evaluate, schedule LR, checkpoint, and early-stop."""

    def __init__(self, model, dataset, max_epochs=50):
        self.model = model        # classifier object exposing train_epoch/evaluate/scheduler
        self.dataset = dataset
        self.max_epochs = max_epochs

        # Checkpointer that persists the best model and tracks early stopping.
        self.checkpoint = ModelCheckpoint(
            filepath='best_flower_classifier.pth',
            patience=8,
            verbose=True,
            delta=0.001
        )

        # Label-mapping metadata stored alongside the weights in each checkpoint.
        self.dataset_info = {
            'class_to_idx': dataset.class_to_idx,
            'idx_to_class': dataset.idx_to_class,
            'category_names': dataset.category_names,
            'num_classes': len(dataset.class_to_idx)
        }

    def train(self, dataloader):
        """Run up to ``max_epochs`` epochs, stopping early on plateau or a 3-epoch perfect streak."""
        print("开始训练流程...")

        for epoch in range(self.max_epochs):
            print(f"\n=== 第 {epoch + 1} 轮训练 ===")

            train_loss, train_acc = self.model.train_epoch(dataloader)
            print(f"训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.4f}%")

            # NOTE(review): evaluation reuses the training loader (no held-out split).
            accuracy, errors = self.model.evaluate(dataloader)
            print(f"评估准确率: {accuracy:.4f}%")

            # Let the plateau scheduler react to the evaluation score.
            self.model.scheduler.step(accuracy)

            # Saves only on improvement; also advances early-stop bookkeeping.
            self.checkpoint(accuracy, self.model.model, self.model.optimizer,
                          None, epoch, self.dataset_info)

            self._show_error_analysis(errors)

            if self.checkpoint.early_stop:
                print("训练早停")
                break

            # Stop once three consecutive epochs come back error-free.
            if self._check_consecutive_correct(errors, epoch) >= 3:
                print("训练成功! 连续三轮无错误识别。")
                break

    def _show_error_analysis(self, errors):
        """Print a short summary of misclassified samples (at most three examples)."""
        if not errors:
            print("完美识别!")
            return

        print(f"发现 {len(errors)} 个识别错误")

        print("\n错误样本示例:")
        for error in errors[:3]:
            true_cat_id = self.dataset.idx_to_class[error['true_label']]
            pred_cat_id = self.dataset.idx_to_class[error['predicted_label']]
            true_info = self.dataset.get_category_info(true_cat_id)
            pred_info = self.dataset.get_category_info(pred_cat_id)

            print(f"{error['filename']}: 正确={true_info['chinese']}({true_cat_id}), "
                  f"预测={pred_info['chinese']}({pred_cat_id})")

    def _check_consecutive_correct(self, errors, epoch):
        """Track how many epochs in a row finished with zero errors; return the streak."""
        # NOTE(review): `epoch` is unused; kept for signature compatibility.
        if errors:
            streak = 0
        else:
            streak = getattr(self, 'consecutive_correct', 0) + 1
            print(f"完美识别! 连续正确次数: {streak}/3")

        self.consecutive_correct = streak
        return streak

def main():
    """Entry point: build the dataset and loader, create or resume a classifier, and train."""
    # Augmentations plus the ImageNet normalization the pretrained backbone expects.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        flower_data = FlowerDataset('train_labels.csv', transform=preprocess)
    except Exception as e:
        print(f"加载数据集失败: {e}")
        return

    if len(flower_data) == 0:
        print("没有找到有效数据")
        return

    loader = DataLoader(
        flower_data,
        batch_size=16,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # Resume from the best saved checkpoint when one already exists on disk.
    checkpoint_path = 'best_flower_classifier.pth'
    resume_training = os.path.exists(checkpoint_path)
    if resume_training:
        print(f"发现已有最佳模型 {checkpoint_path},将从检查点继续训练")
    else:
        print("未找到最佳模型,将从头开始训练")

    classifier = FlowerClassifier(
        num_classes=len(flower_data.class_to_idx),
        checkpoint_path=checkpoint_path if resume_training else None
    )

    manager = TrainingManager(classifier, flower_data, max_epochs=50)
    manager.train(loader)

    print(f"\n训练完成! 最佳模型已保存为 'best_flower_classifier.pth'")

if __name__ == "__main__":
    # Require the label CSV before doing anything else.
    if not os.path.exists('train_labels.csv'):
        print("错误: 找不到 train_labels.csv 文件")
        print("请确保CSV文件与Python脚本在同一目录下")
    else:
        # Preview the first three CSV lines so the user can sanity-check the format.
        print("CSV文件前3行内容:")
        with open('train_labels.csv', 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= 3:
                    break
                print(f"第{i+1}行: {line.strip()}")
        print()

        main()


评论