MNIST(숫자) 이미지 분류페이지 만들기

기본 코드

# import
from fastapi import FastAPI, Request
import uvicorn
from fastapi.templating import Jinja2Templates

# fastapi 선언
app = FastAPI()

# template 폴더 마운트
templates = Jinja2Templates(directory="templates")


# 유비콘 if문
if __name__ == "__main__":
  uvicorn.run(app,host="0.0.0.0", port=8000)


페이지

# 메인페이지
@app.get("/")
def hello(request: Request):
  return templates.TemplateResponse('index.html', \
                                    context = {"request":request})

# 업로드 처리 함수 및 추론 함수 호출
from fastapi import File, UploadFile
import os
@app.post('/uploader')
#비동기 함수 (안전성)
async def uploader_file(request: Request, file: UploadFile = File(...)):
  content = await file.read()
  with open(f"./{file.filename}", 'wb') as fp:
    fp.write(content)

  # 자 파일을 받았으니 추론해보자
  output = infer(f"./{file.filename}")
  return templates.TemplateResponse("CNN_result.html",
                                    {'request': request, 'result': output})


모델 클래스 정의

import torch
# 모델클래스 정의
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
       
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
       
        # 전결합층 7x7x64 inputs -> 10 outputs
        self.fc = torch.nn.Linear(7 * 7 * 64, 10, bias=True)
        # 전결합층 한정으로 가중치 초기화
        torch.nn.init.xavier_uniform_(self.fc.weight)
       
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)   # 전결합층을 위해서 Flatten
        out = self.fc(out)
        return out


추론 함수

from PIL import Image
import torchvision.transforms as transforms

# 추론함수
def infer(filename):
   # 다시불러와서 추론 해보기
    model = CNN()
    model.load_state_dict(torch.load("cnn_model.pt", map_location=torch.device('cpu')))  # cpu로 해야해서
    model.eval() #평가 모드로 설정하여야 합니다. 이 과정을 거치지 않으면 일관성 없는 추론 결과가 출력
    # 학습을 진행하지 않을 것이므로 torch.no_grad(), gradient descent를 하지마라고 명령내리는 것
    with torch.no_grad():
        # 이미지 파일 경로 설정
        img = Image.open(filename)
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1), # RGB(3D) -> Gray(2D)
            transforms.Resize((28, 28)), # 모델 인풋에 맞게
            transforms.ToTensor(), # 토치 텐서 타입으로 맞춰줘야함
        ])

        img_tensor = transform(img) # [1, 28, 28]
        img_tensor = img_tensor.unsqueeze(0) # [1, 1, 28, 28]

        print(img_tensor.shape)

        prediction = model(img_tensor)
                            # CNN은 10개의 아웃풋으로 각 10개의 클래스에 대한 피처값이 나온다, 이를 axis 1방향으로 max값을 찾는다는 것
        result = torch.argmax(prediction, 1) #tensor([결과])
        result = result.tolist()[0] # 결과 라고 나오도록
        return result



댓글 쓰기

다음 이전