最优传输学习及问题总结

文章目录

  • 参考内容
  • lam=0.1
  • lam=3
  • lam=10
  • lam=50
  • lam=100
  • lam=300
  • 画图
  • 线性规划
    • matlab
    • python代码

欢迎关注我们组的微信公众号,更多好文章在等你呦!
微信公众号名:碳硅数据
公众号二维码:
在这里插入图片描述

参考内容

https://blog.csdn.net/qq_41129489/article/details/128830589
https://zhuanlan.zhihu.com/p/542379144

我主要想强调的是这个例子的解法存在的一些细节问题

lam=0.1

lam = 0.1P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

结果如下
在这里插入图片描述

lam=3

lam = 3P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))
print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=10

lam = 10P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=50

lam = 50P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=100

lam = 100P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)
PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=300

lam = 300P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)
PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个就不接近了,之前的求和都是相差在0.001左右,可以近似看作相等
## 但是这个行和是 [2.    1.714 3.75  2.286 2.5   2.5   4.    1.25 ]
## 很明显是 [3. 3. 3. 4. 2. 2. 2. 1.]这个是不对的,所以lam=300时这个值已经发散了,
## 虽然此时的Sinkhorn distance是小于24的,但也不起作用

在这里插入图片描述

画图

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltdef compute_optimal_transport(M=None, r=None, c=None, lam=None, eplison=1e-8):"""Computes the optimal transport matrix and Slinkhorn distance using theSinkhorn-Knopp algorithmInputs:- M : cost matrix (n x m)- r : vector of marginals (n, )- c : vector of marginals (m, )- lam : strength of the entropic regularization- epsilon : convergence parameterOutputs:- P : optimal transport matrix (n x m)- dist : Sinkhorn distance"""r = np.array([3, 3, 3, 4, 2, 2, 2, 1])c = np.array([4, 2, 6, 4, 4])M = np.array([[2, 2, 1, 0, 0], [0, -2, -2, -2, -2], [1, 2, 2, 2, -1], [2, 1, 0, 1, -1],[0.5, 2, 2, 1, 0], [0, 1, 1, 1, -1], [-2, 2, 2, 1, 1], [2, 1, 2, 1, -1]],dtype=float) M = -M # 将M变号,从偏好转为代价n, m = M.shape  # 8, 5P = np.exp(-lam * M) # (8, 5)P /= P.sum()  # 归一化u = np.zeros(n) # (8, )# normalize this matrixwhile np.max(np.abs(u - P.sum(1))) > eplison: # 这里是用行和判断收敛# 对行和列进行缩放,使用到了numpy的广播机制,不了解广播机制的同学可以去百度一下u = P.sum(1) # 行和 (8, )P *= (r / u).reshape((-1, 1)) # 缩放行元素,使行和逼近rv = P.sum(0) # 列和 (5, )P *= (c / v).reshape((1, -1)) # 缩放列元素,使列和逼近creturn P, np.sum(P * M) # 返回分配矩阵和Sinkhorn距离lam_list=[1,5,10,20,30,40,50,60,70,80,90,100,110,120,130,140,150]cost_list=[]
for lam in lam_list:P, d = compute_optimal_transport(lam=lam)cost_list.append(d)
print(cost_list)
plt.plot(np.array(lam_list),np.array(cost_list),c="g")
plt.show()## 现在这个地方也有的

在这里插入图片描述
这个地方其实有一个画图的小问题,我待会要再写一下

可以看到大概是在lam =150的时候,就已经不稳定了,所以这个例子的问题的解的最小花费约等于24,但是我发现一个更有意思的问题,就是这个分配矩阵是唯一的吗,很显然不是的, 利用我上篇文章学到的线性规划,我发现matlab和python找到的是两个不同的解,

线性规划

matlab

clc;
clear;r = [3, 3, 3, 4, 2, 2, 2, 1];
c = [4, 2, 6, 4, 4];
cost_matrix =  [2, 2, 1, 0, 0;0, -2, -2, -2, -2; 1, 2, 2, 2, -1;2, 1, 0, 1, -1;0.5, 2, 2, 1, 0;0, 1, 1, 1, -1;-2, 2, 2, 1, 1;2, 1, 2, 1, -1];cost_matrix_t = (-1)*transpose(cost_matrix);% 需要有符号
cost_vec = cost_matrix_t(:);raw_equ = zeros(8,40);
for i =1:8raw_equ(i,((i-1)*5+1):((i-1)*5+5))=1;
endcol_equ = zeros(5,40);
for i =1:5for j =1:8col_equ(i,i+(j-1)*5)=1;end
endequ = [raw_equ;col_equ];
equ_value = horzcat(r, c);
% x1,x2,x3,x4,x5
% x6,x7,x8,x9,x10
% x11,x12,x13,x14,x15
% x16,x17,x18,x19,x20
% x21,x22,x23,x24,x25
% x26,x27,x28,x29,x30
% x31,x32,x33,x34,x35
% x36,x37,x38,x39,x40% 现在我要求的变量是这样的,
f=cost_vec;			% 价值向量
a=[];	% a、b对应不等式的左边和右边
b=[];
aeq=equ;	% aeq和beq对应等式的左边和右边
beq=equ_value;
[x,y]=linprog(f,a,b,aeq,beq,zeros(40,1));arr_mat = transpose(reshape(x',5,8));

结果如下
在这里插入图片描述
分配矩阵如下在这里插入图片描述

python代码

# Define parameters
m = 8
n = 5p = np.array([3, 3, 3, 4, 2, 2, 2, 1])
q = np.array([4, 2, 6, 4, 4])C = -1*np.array([[2, 2, 1, 0, 0], [0, -2, -2, -2, -2], [1, 2, 2, 2, -1], [2, 1, 0, 1, -1],[0.5, 2, 2, 1, 0], [0, 1, 1, 1, -1], [-2, 2, 2, 1, 1], [2, 1, 2, 1, -1]],dtype=float)# Vectorize matrix C
C_vec = C.reshape((m*n, 1), order='F')# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))
A2 = np.kron(np.identity(n), np.ones((1, m)))
A = np.vstack([A1, A2])# Construct vector b
b = np.hstack([p, q])# Solve the primal problem
res = linprog(C_vec, A_eq=A, b_eq=b)# Print results
print("message:", res.message)
print("nit:", res.nit)
print("fun:", res.fun)
print("z:", res.x)
print("X:", res.x.reshape((m,n), order='F'))

结果如下
在这里插入图片描述
可以看到花费都是24,但是两者的分配矩阵并不一样哈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/243195.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用ffmpeg调整视频中音频采样率及声道

1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg -i example2.mp4 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable…

网页无法访问但是有网什么原因

目录 1.运行网络诊断,确认原因 原因A.远程计算机或设备将不接受连接(该设备或资源(Web 代理)未设置为接受端口“7890”上的连接 原因B.DNS服务器未响应 场景A.其他的浏览器可以打开网页,自带的Edge却不行 方法A:关闭代理 Google自带翻译…

C++中命名空间、缺省参数、函数重载

目录 1.命名空间 2.缺省参数 3.函数重载 1.命名空间 在C中定义命名空间我们需要用到namespace关键字,后面跟上命名空间的名字,结构框架有点类似结构体(如图所示) 上面的代码我一一进行讲解: 1.我们先来说第三行和main函…

如何搭建MariaDB并实现无公网ip环境远程连接本地数据库

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” 文章目录 1. 配置MariaDB数据库1.1 安装MariaDB数据库1.2 测试局域网内远程连接 2. 内网穿透2.1 创建隧道映射…

【C++】priority_queue模拟实现过程中值得注意的点

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 🌝每一个不曾起舞的日子,都是对生命的辜负 前言 本篇文章旨在记录博主在模…

HCIA vlan练习

目录 实验拓扑 实验要求 实验步骤 1、交换机创建vlan 2、交换机上的各个接口划分到对应vlan中 3、trunk干道 4、路由器单臂路由 5、路由器DHCP设置 实验测试 华为交换机更换端口连接模式报错处理 实验拓扑 实验要求 根据图划分vlan,并通过DHCP给主机下发…

Docker(三)使用 Docker 镜像:从仓库获取镜像;管理本地主机上的镜像;介绍镜像实现的基本原理

作者主页: 正函数的个人主页 文章收录专栏: Docker 欢迎大家点赞 👍 收藏 ⭐ 加关注哦! 使用 Docker 镜像 在之前的介绍中,我们知道镜像是 Docker 的三大组件之一。 Docker 运行容器前需要本地存在对应的镜像&#x…

c语言冒泡排序

系列文章目录 c语言冒泡排序 c语言冒泡排序 系列文章目录一、冒泡排序原理二、冒泡排序案例 一、冒泡排序原理 有几个数就需要排序几次-1 从数组第一个元素开始和相邻的元素比对,大的元素放在后面,小的放在前面 如,428057139 4与2对比&#x…

「优选算法刷题」:在排序数组中查找元素的第一个和最后个位置

一、题目 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1, -1]。 你必须设计并实现时间复杂度为 O(log n) 的算法解决此问题。 示例 1&a…

Spring--@Async解析

一、Async 简介 从Spring3开始提供了Async注解,被该注解标注的方法,Spring底层会新建一个线程池或者使用已有的线程池中的线程去异步的执行被标注的方法。 二、Async 工作原理 Async与Transactional 工作原理基本是一样的,也是通过Spring …

【数据结构初阶】——顺序表

本文由睡觉待开机原创,转载请注明出处。 本内容在csdn网站首发 欢迎各位点赞—评论—收藏 如果存在不足之处请评论留言,共同进步! 这里写目录标题 1.数据结构2.顺序表线性表顺序表的结构 3.动态顺序表的实现 1.数据结构 数据结构的概念&…

【JavaEE进阶】 Spring Boot⽇志

文章目录 🎋关于日志🚩为什么要学习⽇志🚩⽇志的⽤途🚩日志的简单使用 🎄打印⽇志🚩程序中得到⽇志对象🚩使⽤⽇志对象打印⽇志 🎍⽇志格式的说明🚩⽇志级别的作用&#…

QQ数据包解密

Windows版qq数据包格式&#xff1a; android版qq数据包格式&#xff1a; 密钥&#xff1a;16个0 算法&#xff1a;tea_crypt算法 pc版qq 0825数据包解密源码&#xff1a; #include "qq.h" #include "qqcrypt.h" #include <WinSock2.h> #include…

Win10下在Qt项目中配置SQlite3环境

资源下载 官网资源&#xff1a;SQLite Download Page 1、sqlite.h sqlite-amalgamation-3450000.zip (2.60 MiB) 2、sqlite3.def&#xff0c;sqlite3.dll sqlite-dll-win-x64-3450000.zip (1.25 MiB) 3、 win10下安装sqlite3所需要文件 sqlite-tools-win-x64-3450000.zipht…

node介绍

1.node是什么 Node是一个基于Chrome V8引擎的JS运行环境。 Node不是一个独立的语言、node不是JS框架。 Node是一个除了浏览器之外的、可以让JS运行的环境 Node.js是一个让JS运行在服务端的开发平台&#xff0c;是使用事件驱动&#xff0c;异步非阻塞I/O&#xff0c;单线程&…

fastJson和jackson的日期数据处理

目录 1.jackson 2.fastjson 3.总结 1.jackson jackson是spring mvc默认的JSON解析方法&#xff0c;前端的数据序列化处理之后&#xff0c;后端经过反序列化处理可以直接使用实体对象进行接收。后端接口返回实体对象&#xff0c;经过序列化处理后前端可以接收并进行处理。 …

回归预测 | Matlab基于ABC-SVR人工蜂群算法优化支持向量机的数据多输入单输出回归预测

回归预测 | Matlab基于ABC-SVR人工蜂群算法优化支持向量机的数据多输入单输出回归预测 目录 回归预测 | Matlab基于ABC-SVR人工蜂群算法优化支持向量机的数据多输入单输出回归预测预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab基于ABC-SVR人工蜂群算法优化支持…

C++提高编程---模板---类模板

目录 一、类模板 1.模板 2.类模板的作用 3.语法 4.声明 二、类模板和函数模板的区别 三、类模板中成员函数的创建时机 四、类模板对象做函数参数 五、类模板与继承 六、类模板成员函数类外实现 七、类模板分文件编写 八、类模板与友元 九、类模板案例 一、类模板 …

第14章_集合与数据结构拓展练习(前序、中序、后序遍历,线性结构,单向链表构建,单向链表及其反转,字符串压缩)

文章目录 第14章_集合与数据结构拓展练习选择填空题1、前序、中序、后序遍历2、线性结构3、其它 编程题4、单向链表构建5、单向链表及其反转6、字符串压缩 第14章_集合与数据结构拓展练习 选择填空题 1、前序、中序、后序遍历 分析&#xff1a; 完全二叉树&#xff1a; 叶结点…

ElasticSearch的常用增删改查DSL和代码

es增删改查常用语法 我们日常开发中&#xff0c;操作数据库写sql倒是不可能忘记&#xff0c;但是操作es的dsl语句有时候很容易忘记&#xff0c;特地记录一下方便查找。 DSL语句 1、创建索引 -- 创建索引 PUT /my_index {"mappings": {"properties": {&…