2019년 1월 28일 월요일

pytorch로 multi task learning 학습 할 때 주의해야 할 점

v2019.01.28

pytorch를 사용하면서 구글링으로 해결하기 어려웠던 내용들을 정리해본다.

우선,
multi task learning으로 학습을 해야 할 때가 있다.
(혹은 RL에서 Auxiliary task)

이 때 optimizer를 선언하는 방법은 아래와 같이 하면 된다.

optimizer = torch.optim.Adam(list(model1.parameters())+list(model2.parameters())+...)


두번째로
loss를 각각 정의 해준 뒤 하나로 합쳐준다.

optimizer.zero_grad()      # clear gradient
total_loss = model1_loss + model2_loss + ...     #calculate and summation loss
total_loss.backward(retain_graph=True)  # backpropagation, compute gradient
optimizer.step()        #update

loss마다 각각 backward()를 했을 때는 에러로 retain_graph=True를 넣으라고 뜨지만 넣어줘도 같은 에러가 뜨고 죽는다. loss를 합쳐서 한번에 backward 하면 죽지 않는다.

retain_graph를 True로 설정하면 backward를 해도 gradient를 남겨놓음으로써 이후에 backward를 할 수 있게 되며, 나의 경우 retrain graph=True로 설정하면 약 60mb정도의 메모리를 더 소요하는 것 외에 큰 차이를 보지 못했다. (pytorch docs에는 효율적인 사용을 위해 retain_graph 파라미터에 대해 따로 건들 일이 거의 없다고 되어있다)

댓글 없음:

댓글 쓰기